Code Example #1
def model(mtf_features,
          other_features,
          params,
          mesh,
          variable_dtype,
          context=None):
    """A GPT style model implemented in mesh tensorflow."""

    x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(
        mtf_features, other_features)

    if is_incremental_inference(context):
        # reshape inputs if in inference mode
        x = mtf.gather(x, context.position - 1, sequence_dim)
        x = mtf.reshape(x, [batch_dim])

    use_axial_pos_emb = params["axial_pos_emb"] is not None

    if not use_axial_pos_emb:
        # Use standard position encoding
        wpe = mtf.get_variable(
            mesh,
            "wpe",
            mtf.Shape([embed_sequence_dim, embd_dim]),
            initializer=tf.random_normal_initializer(stddev=0.01),
            master_dtype=variable_dtype.master_dtype,
            slice_dtype=variable_dtype.slice_dtype,
            activation_dtype=variable_dtype.activation_dtype)
    else:
        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)

    # Text encoding
    wte = mtf.get_variable(
        mesh,
        "wte",
        mtf.Shape([vocab_dim, embd_dim]),
        initializer=tf.random_normal_initializer(stddev=0.02),
        master_dtype=variable_dtype.master_dtype,
        slice_dtype=variable_dtype.slice_dtype,
        activation_dtype=variable_dtype.activation_dtype)

    with tf.variable_scope("token_embd"):
        # Text embedding
        h = mtf.gather(wte, x, vocab_dim)
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            h = mtf.dropout(h,
                            rate=params["embed_dropout"],
                            name="wte_dropout")

    with tf.variable_scope("pos_embd"):
        # Positional embedding
        position_indices = mtf.range(
            mesh, sequence_dim,
            tf.int64) if not is_incremental_inference(context) else (
                context.position - 1)
        pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            pos_emb = mtf.dropout(pos_emb,
                                  rate=params["embed_dropout"],
                                  name="wpe_dropout")
        h += pos_emb

    aux_losses = 0  # instantiate auxiliary losses (for MOE models)

    for layer in range(params["n_layer"]):
        # attn blocks
        share_parameters = exists(
            params["share_parameters"]) and params["share_parameters"]
        block_scope = f"h{layer}" if not share_parameters else ""

        block_fn = block(params=params,
                         scope=block_scope,
                         layer_num=layer,
                         bias=other_features["attn_bias"],
                         sequence_dim=sequence_dim,
                         memory_length_dim=other_features["memory_length_dim"],
                         variable_dtype=variable_dtype,
                         context=context)

        # If true and in train mode, enable gradient checkpointing
        recompute_grad = params["recompute_grad"] and params["mode"] == "train"
        h, loss = block_fn(h) if not recompute_grad else mtf.recompute_grad(
            block_fn, [h])
        aux_losses += loss

    no_weight_tie_emb = params["no_weight_tie"]
    if no_weight_tie_emb:
        with tf.variable_scope("wte_final_linear"):
            logits = linear(h,
                            "linear_out",
                            vocab_dim,
                            variable_dtype=variable_dtype,
                            params=params)
    else:
        # Layer normalize & affine transform
        h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
        seq_dim = sequence_dim if not is_incremental_inference(
            context) else mtf.Dimension("sequence", 1)
        with tf.variable_scope("wte_final_einsum"):
            # Equivalent to tf.matmul
            logits = mtf.einsum([h, wte],
                                output_shape=[batch_dim, seq_dim, vocab_dim])

    if params["mode"] in ["train", "eval"]:
        labels = mtf_features["labels"]
        z_loss = params.get(
            "z_loss", 1e-4)  # an auxiliary loss used to stabilize mtf xentropy

        # Go to full precision for the logits
        logits = mtf.cast(logits, tf.float32)

        use_entmax_loss = params.get("entmax_loss", False)
        loss_fn = mtf.layers.softmax_cross_entropy_with_logits if not use_entmax_loss else entmax_cross_entropy_with_logits

        with tf.variable_scope("xentropy_final"):
            loss_batch = loss_fn(logits=logits,
                                 targets=labels,
                                 vocab_dim=logits.shape[-1],
                                 z_loss=z_loss)

        # For non-autoregressive models (masked language modeling training)
        # Make sure labels with padding tokens are not counted in the loss
        if not params["causal"]:
            padding_id = params.get("padding_id", 0)
            loss_batch = mtf.where(mtf.not_equal(labels, padding_id),
                                   loss_batch, mtf.zeros_like(loss_batch))

        with tf.variable_scope("reduce_mean_final"):
            loss = mtf.reduce_mean(loss_batch)

        loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)
        loss /= params["num_microbatches"]
        # Convert to train dtype
        loss = mtf.cast(loss, variable_dtype.slice_dtype)
    else:
        loss = None
        loss_batch = None

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, variable_dtype.master_dtype)
    return logits, loss, loss_batch
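
The z_loss term above is only described by a comment; here is a rough NumPy sketch of the idea (not the mesh-tensorflow implementation, and all values are illustrative): penalizing the squared log-partition function keeps logit magnitudes from drifting, which stabilizes the cross entropy at low precision.

import numpy as np

def xent_with_z_loss(logits, target_id, z_loss=1e-4):
    # Numerically stable log-partition function log Z = logsumexp(logits).
    log_z = np.log(np.sum(np.exp(logits - logits.max()))) + logits.max()
    # Standard cross entropy plus the auxiliary z_loss * (log Z)^2 penalty.
    return (log_z - logits[target_id]) + z_loss * log_z ** 2

# Example: three-token vocabulary, correct token is index 0.
print(xent_with_z_loss(np.array([2.0, -1.0, 0.5]), target_id=0))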
Code Example #2
 def value_dim(self):
     """Dimensionality of attention value."""
     if self.config.attention_value_head_size is None:
         raise ValueError("The value head size is not defined.")
     return mtf.Dimension("d_v", self.config.attention_value_head_size)
Code Example #3
def transformer_moe_layer_v1(
    inputs, output_dim, hparams, train, variable_dtype,
    layout=None, mesh_shape=None, nonpadding=None):
  """Local mixture of experts that works well on TPU.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py.
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  The number of parameters in the gating network is:
    (input_dim.size * hparams.moe_num_experts)

  The number of parameters in the experts themselves is:
    (hparams.moe_num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims...>, length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional Tensor with shape [<batch_dims>, length_dim]
      and the same dtype as inputs, consisting of ones(nonpadding)
      and zeros(padding).

  Returns:
    outputs: a Tensor with shape [<batch_dims...>, length_dim, output_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
  orig_inputs = inputs
  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
  experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups is a multiple of the mesh dimension
  # over which those groups are split.
  orig_batch_dim, orig_length_dim, input_dim = orig_inputs.shape.dims

  num_groups, group_size = _split_into_groups(
      orig_batch_dim.size * orig_length_dim.size, hparams.moe_group_size,
      mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, orig_batch_dim))
  group_size_dim = mtf.Dimension("group", group_size)
  batch_dim = mtf.Dimension(orig_batch_dim.name, num_groups)
  inputs = mtf.reshape(inputs, [batch_dim, group_size_dim, input_dim])

  # Each sequence sends expert_capacity positions to each expert.
  if train:
    capacity_factor = hparams.moe_capacity_factor_train
  else:
    capacity_factor = hparams.moe_capacity_factor_eval
  expert_capacity = min(
      group_size_dim.size,
      int((group_size_dim.size * capacity_factor) / experts_dim.size))
  expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

  experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
  batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)
  if nonpadding is not None:
    nonpadding = mtf.zeros(inputs.mesh, [orig_batch_dim, orig_length_dim],
                           dtype=inputs.dtype) + nonpadding
    nonpadding = mtf.reshape(nonpadding, [batch_dim, group_size_dim])
  if hparams.moe_gating == "top_2":
    dispatch_tensor, combine_tensor, loss = _top_2_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  # put num_experts dimension first to make split easier in alltoall
  expert_inputs = mtf.einsum([inputs, dispatch_tensor], mtf.Shape(
      [experts_dim_unsplit, batch_dim, expert_capacity_dim, input_dim]))

  expert_inputs = mtf.reshape(expert_inputs, mtf.Shape(
      [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

  # Now feed the expert inputs through the experts.
  h = mtf.layers.dense(
      expert_inputs, hidden_dim, expert_dims=[experts_dim],
      activation=mtf.relu, use_bias=False,
      variable_dtype=variable_dtype, name="wi")
  expert_output = mtf.layers.dense(
      h, output_dim, expert_dims=[experts_dim], use_bias=False,
      variable_dtype=variable_dtype,
      name="wo")

  expert_output = mtf.reshape(expert_output, mtf.Shape(
      [experts_dim_unsplit, batch_dim, expert_capacity_dim, input_dim]))

  output = mtf.einsum([expert_output, combine_tensor], mtf.Shape(
      [batch_dim, group_size_dim, output_dim]))

  output = mtf.reshape(output, orig_inputs.shape.dims[:-1] + [output_dim])

  return output, loss * hparams.moe_loss_coef
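
To make the grouping and capacity arithmetic described in the docstring above concrete, here is a small illustrative calculation with assumed sizes. The group-splitting step is a simplified stand-in for _split_into_groups, which additionally accounts for the mesh dimension size passed to it.

batch, length = 8, 1024                  # assumed batch and sequence sizes
moe_group_size = 1024                    # hparams.moe_group_size
num_experts = 16                         # hparams.moe_num_experts
capacity_factor = 1.25                   # hparams.moe_capacity_factor_train

# Simplified grouping: pick the largest group size <= moe_group_size that
# divides batch * length evenly (the real helper also respects the mesh split).
total_positions = batch * length
group_size = moe_group_size
while total_positions % group_size:
    group_size -= 1
num_groups = total_positions // group_size

# Each group sends at most expert_capacity positions to each expert,
# mirroring the expert_capacity computation in transformer_moe_layer_v1.
expert_capacity = min(group_size, int(group_size * capacity_factor / num_experts))
print(num_groups, group_size, expert_capacity)   # -> 8 1024 80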
Code Example #4
 def feedforward_intermediate_dim(self):
     return mtf.Dimension("intermediate",
                          self.config.feedforward_intermediate_size)
Code Example #5
 def max_position_embeddings_dim(self):
     return mtf.Dimension("max_position_embeddings",
                          self.config.max_position_embeddings)
Code Example #6
    def __init__(self,
                 config,
                 is_training,
                 input_ids,
                 input_mask=None,
                 token_type_ids=None,
                 scope=None,
                 mesh_shape="",
                 layout=""):
        self.config = copy.deepcopy(config)
        del config
        if not is_training:
            self.config.layer_output_dropout_prob = 0.0
            self.config.attention_probs_dropout_prob = 0.0
            self.config.feedforward_intermediate_dropout_prob = 0.0
        input_shape = input_ids.shape
        assert input_shape.ndims == 2

        self._seq_dim = input_shape.dims[1]
        self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size)
        self._extra_losses = []
        mesh = input_ids.mesh

        if token_type_ids is None:
            token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32)

        with tf.variable_scope(scope, default_name="bert"):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                self.embedding_table = mtf.get_variable(
                    mesh,
                    "word_embeddings",
                    mtf.Shape([self.vocab_dim, self.model_dim]),
                    initializer=self.embedding_initializer)
                self.word_embedding_output = mtf.gather(
                    self.embedding_table, input_ids, self.vocab_dim)

                # Add positional embeddings and token type embeddings, then layer
                # normalize and perform dropout.
                self.embedding_output = self.word_embedding_output

                token_type_table = mtf.get_variable(
                    mesh,
                    "token_type_embeddings",
                    mtf.Shape([self.token_type_vocab_dim, self.model_dim]),
                    initializer=self.embedding_initializer)
                if token_type_ids is not None:
                    self.embedding_output += mtf.gather(
                        token_type_table, token_type_ids,
                        self.token_type_vocab_dim)
                if self.config.position_signal == "embedding":
                    full_position_table = mtf.get_variable(
                        mesh,
                        "position_embeddings",
                        mtf.Shape(
                            [self.max_position_embeddings_dim,
                             self.model_dim]),
                        initializer=self.embedding_initializer)
                    short_position_table = mtf.rename_dimension(
                        mtf.slice(full_position_table, 0, self.seq_dim.size,
                                  self.max_position_embeddings_dim.name),
                        self.max_position_embeddings_dim.name,
                        self.seq_dim.name)
                    self.embedding_output += short_position_table
                self.embedding_output = self.normalize(self.embedding_output)
                self.embedding_output = mtf.dropout(
                    self.embedding_output,
                    keep_prob=1.0 - self.config.layer_output_dropout_prob)

            with tf.variable_scope("encoder"):
                attention_biases = []
                if input_mask is not None:
                    # [batch_dim, memory_seq_dim]
                    attention_biases.append((1.0 - mtf.to_float(
                        mtf.replace_dimensions(input_mask, self.seq_dim,
                                               self.memory_seq_dim))) *
                                            -10000.0)
                if self.config.position_signal == "relative_attention_bias":
                    buckets_dim = mtf.Dimension("buckets", 32)
                    rp_bucket = _relative_position_bucket(
                        mtf.range(mesh, self.memory_seq_dim, tf.int32) -
                        mtf.range(mesh, self.seq_dim, tf.int32),
                        num_buckets=buckets_dim.size)
                    bias_var = mtf.get_variable(
                        mesh,
                        "relative_attention_bias",
                        [self.num_heads_dim, buckets_dim],
                        initializer=tf.zeros_initializer())
                    attention_biases.append(
                        mtf.gather(bias_var, rp_bucket, buckets_dim))
                attention_bias = mtf.add_n(attention_biases)
                prev_layer_output = self.embedding_output
                self.all_encoder_layers = []
                for block_num in range(self.config.num_blocks):
                    with tf.variable_scope("block_%d" % block_num):
                        for layer_idx, layer_type in enumerate(
                                self.config.block_layers):
                            layer_name = layer_type
                            count = self.config.block_layers[:layer_idx].count(
                                layer_type)
                            if count:
                                layer_name += "_%d" % count
                            with tf.variable_scope(layer_name):
                                x = prev_layer_output
                                if self.config.residual_structure == "direct":
                                    x = self.normalize(x)
                                if layer_type == "attention":
                                    x = self.self_attention(x, attention_bias)
                                elif layer_type == "feedforward":
                                    x = self.feedforward(x)
                                elif layer_type == "moe":
                                    x = self.moe(x, layout, mesh_shape,
                                                 input_mask, is_training)
                                else:
                                    raise ValueError("unknown layer type " +
                                                     layer_type)
                                x = mtf.dropout(
                                    x,
                                    keep_prob=1.0 -
                                    self.config.layer_output_dropout_prob)
                                layer_output = prev_layer_output + x
                                if self.config.residual_structure == "original":
                                    layer_output = self.normalize(layer_output)
                                prev_layer_output = layer_output
                    self.all_encoder_layers.append(layer_output)

            self.sequence_output = prev_layer_output
            if self.config.residual_structure == "direct":
                self.sequence_output = self.normalize(self.sequence_output)

            # The "pooler" converts the encoded sequence tensor of shape
            # [batch_dim, seq_dim, hidden_size] to a tensor of shape
            # [batch_dim, hidden_size]. This is necessary for segment-level
            # (or segment-pair-level) classification tasks where we need a fixed
            # dimensional representation of the segment.
            with tf.variable_scope("pooler"):
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token. We assume that this has been pre-trained
                first_token_tensor = mtf.gather(self.sequence_output, 0,
                                                self.seq_dim)
                self.pooled_output = mtf.layers.dense(
                    first_token_tensor,
                    reduced_dims=[self.model_dim],
                    new_dims=[self.model_dim],
                    activation=mtf.tanh,
                    kernel_initializer=self.dense_initializer,
                    use_bias=self.config.use_bias)
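
The encoder above turns input_mask into an additive bias of (1 - mask) * -10000. A minimal NumPy sketch (values assumed) of why this works: padding positions receive a large negative logit, so after the softmax over memory positions their attention weight is effectively zero.

import numpy as np

input_mask = np.array([1., 1., 1., 0., 0.])       # 1 = real token, 0 = padding
bias = (1.0 - input_mask) * -10000.0              # same form as the encoder above
logits = np.array([0.3, 1.2, -0.7, 0.9, 0.4]) + bias
weights = np.exp(logits - logits.max())
weights /= weights.sum()
print(weights.round(3))                           # padding positions get ~0 weight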
Code Example #7
 def model_dim(self):
     return mtf.Dimension("hidden", self.config.d_model)
Code Example #8
 def testConvertToShape(self, inputs):
     shape = mtf.convert_to_shape(inputs)
     self.assertEqual(
         shape, mtf.Shape([mtf.Dimension("x", 4),
                           mtf.Dimension("y", 8)]))
Code Example #9
File: run_classifier.py  Project: bruinxiong/mesh-1
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        # MTF setup.
        graph = mtf.Graph()
        mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info("device_list = %s" % device_list, )
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                    mesh_devices,
                                                    ctx.device_assignment)
        mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        is_real_example = None
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        batch_size = input_ids.get_shape()[0].value
        batch_dim = mtf.Dimension("batch", batch_size)
        seq_length = input_ids.get_shape()[1].value
        seq_dim = mtf.Dimension("seq", seq_length)
        num_labels_dim = mtf.Dimension("seq", num_labels)
        mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                             [batch_dim, seq_dim])
        mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                              [batch_dim, seq_dim])
        mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                               [batch_dim, seq_dim])
        mtf_label_ids = mtf.import_tf_tensor(mesh, label_ids, [batch_dim])

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        (total_loss, per_example_loss, logits,
         probabilities) = create_model(bert_config, is_training, mtf_input_ids,
                                       mtf_input_mask, mtf_segment_ids,
                                       mtf_label_ids, num_labels_dim,
                                       layout_rules, mesh_shape)
        total_loss = mtf.anonymize(total_loss)
        per_example_loss = mtf.anonymize(per_example_loss)
        logits = mtf.anonymize(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            _, update_ops = optimization_lib.create_optimizer(
                total_loss,
                learning_rate,
                num_train_steps,
                num_warmup_steps,
                max_optimized_variable_size=FLAGS.max_optimized_variable_size,
                optimizer=FLAGS.optimizer,
                clip_gradients=FLAGS.clip_gradients)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_global_step()
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(labels=label_ids,
                                               predictions=predictions,
                                               weights=is_real_example)
                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                return {
                    "eval_accuracy": accuracy,
                    "eval_loss": loss,
                }

            eval_metrics = (metric_fn, [
                lowering.export_to_tf_tensor(per_example_loss), label_ids,
                lowering.export_to_tf_tensor(logits), is_real_example
            ])

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = bert_lib.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            if mode == tf.estimator.ModeKeys.TRAIN:
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=10,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    FLAGS.output_dir,
                    save_steps=1000,
                    saver=saver,
                    listeners=[saver_listener])

                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook, saver_hook],
                    scaffold_fn=scaffold_fn)
            elif mode == tf.estimator.ModeKeys.EVAL:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode,
                    prediction_hooks=[restore_hook],
                    predictions={
                        "probabilities":
                        lowering.export_to_tf_tensor(probabilities)
                    },
                    scaffold_fn=scaffold_fn)
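
The model_fn above reads FLAGS.mesh_shape and FLAGS.layout as strings. Below is a short sketch of what those strings can look like and how the resulting layout assigns tensor dimensions to mesh axes; the concrete values are assumptions, and the MeshImpl calls mirror the test cases in Code Example #11.

import mesh_tensorflow as mtf

mesh_shape = mtf.convert_to_shape("batch:8")               # an 8-core, purely data-parallel mesh
layout_rules = mtf.convert_to_layout_rules("batch:batch")  # split the "batch" tensor dim over it
mesh_impl = mtf.MeshImpl(shape=mesh_shape, layout_rules=layout_rules)

# The "batch" dimension of any tensor is split over mesh axis 0; unlisted dims are replicated.
print(mesh_impl.tensor_dimension_to_mesh_axis(mtf.Dimension("batch", 32)))  # 0
print(mesh_impl.tensor_dimension_to_mesh_axis(mtf.Dimension("seq", 128)))   # None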
Code Example #10
def Alexnet(img, labels, num_nodes, num_gpus, args):
    num_classes = 1000
    keep_prob = 0.5
    learning_rate = 0.01
    graph, meshes, mesh_to_impl, mtf_img, mtf_labels = CreateMeshes(
        img, labels, num_nodes, num_gpus, args)
    RenameFC = lambda x: mt.rename_dimension(x, x.shape[-1].name,
                                             utils.RandName())

    strategy = args.strategy
    if strategy == 0:
        fc6_units = mtf.Dimension(utils.RandName(), 4096)
        fc7_units = mtf.Dimension(utils.RandName(), 4096)
        fc8_units = mtf.Dimension(utils.RandName(), num_classes)

    elif strategy == 1:
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)

    elif strategy == 2:
        num_classes = utils.RoundUp(num_classes, num_gpus)
        fc6_units = mtf.Dimension('axis0', 4096)
        fc7_units = mtf.Dimension('axis0', 4096)
        fc8_units = mtf.Dimension('axis0', num_classes)

    elif strategy == 3:
        num_classes = utils.RoundUp(num_classes, num_gpus // 2)
        fc6_units = mtf.Dimension('axis1', 4096)
        fc7_units = mtf.Dimension('axis1', 4096)
        fc8_units = mtf.Dimension('axis1', num_classes)

    with tf.variable_scope('alexnet'):
        # Conv1 + ReLU + maxpool1
        conv1 = mt.Conv2d(mtf_img,
                          GetFilterShape(mtf_img, (11, 11, 3, 96)), (4, 4),
                          'VALID',
                          activation=mtf.relu,
                          name='conv1')
        pool1 = mt.MaxPool(conv1, (3, 3), (2, 2), 'VALID', name='pool1')

        # Conv2 + ReLU + maxpool2
        conv2 = mt.Conv2d(pool1,
                          GetFilterShape(pool1, (5, 5, 96, 256)), (1, 1),
                          'SAME',
                          activation=mtf.relu,
                          name='conv2')
        pool2 = mt.MaxPool(conv2, (3, 3), (2, 2), name='pool2')

        # Conv3 + ReLU
        conv3 = mt.Conv2d(pool2,
                          GetFilterShape(pool2, (3, 3, 256, 384)),
                          padding='SAME',
                          activation=mtf.relu,
                          name='conv3')

        # Conv4 + ReLU
        conv4 = mt.Conv2d(conv3,
                          GetFilterShape(conv3, (3, 3, 384, 384)),
                          padding='SAME',
                          activation=mtf.relu,
                          name='conv4')

        # Conv5 + ReLU + maxpool5
        conv5 = mt.Conv2d(conv4,
                          GetFilterShape(conv4, (3, 3, 384, 256)),
                          padding='SAME',
                          activation=mtf.relu,
                          name='conv5')
        pool5 = mt.MaxPool(conv5, (3, 3), (2, 2), name='pool5')

        # Rename dims
        if strategy == 1:
            k_dim = mtf.Dimension(utils.RandName(),
                                  utils.Prod(pool5.shape.to_integer_list[1:]))
            pool5 = mtf.reshape(pool5, mtf.Shape([pool5.shape[0], k_dim]))
            pool5 = ReplaceMeshWithIndependentAxes(pool5, meshes[1],
                                                   (utils.RandName(), 'axis0'))

        elif strategy == 2:
            pool5 = mt.rename_dimension(pool5, pool5.shape[0].name,
                                        utils.RandName())

        elif strategy == 3:
            assert pool5.shape[0].name == 'axis0'
            #dim_names = pool5.shape.rename_dimension('axis0', utils.RandName())
            #pool5 = ReplaceMeshWithIndependentAxes(pool5, meshes[1], dim_names)
            pool5 = ReplaceMeshWithConcatSplit(pool5, meshes[1])

        # FC + ReLU + dropout
        fc_activation = lambda x: mtf.dropout(mtf.relu(x), keep_prob)
        fc6 = mtf.layers.dense(pool5,
                               fc6_units,
                               activation=fc_activation,
                               reduced_dims=pool5.shape[1:],
                               name='fc6')
        if strategy in (2, 3):
            fc6 = RenameFC(fc6)

        fc7 = mtf.layers.dense(fc6,
                               fc7_units,
                               activation=fc_activation,
                               reduced_dims=fc6.shape.dims[-1:],
                               name='fc7')
        if strategy in (2, 3):
            fc7 = RenameFC(fc7)

        fc8 = mtf.layers.dense(fc7,
                               fc8_units,
                               reduced_dims=fc7.shape.dims[-1:],
                               name='fc8')
        fc8 = mtf.dropout(fc8, keep_prob)

        if strategy == 1:
            assert fc8.shape[-1].name == 'axis1'
            fc8 = ReplaceMeshWithDuplicates(fc8, meshes[2])

    with tf.variable_scope('loss'):
        if fc8.shape[0] != mtf_labels.shape[0]:
            fc8 = mt.rename_dimension(fc8, fc8.shape[0].name,
                                      mtf_labels.shape[0].name)
        one_hot_labels = mtf.one_hot(mtf_labels, fc8.shape[-1])
        mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
            fc8, one_hot_labels, fc8.shape[-1])
        mtf_loss = mtf.reduce_mean(mtf_cross_ent)

    return graph, mesh_to_impl, mtf_loss
Code Example #11
class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):
    @parameterized.parameters(
        (mtf.Dimension("x", 5), ),
        (("x", 5), ),
    )
    def testConvertToDimension(self, inputs):
        dimension = mtf.convert_to_dimension(inputs)
        self.assertEqual(dimension.name, "x")
        self.assertEqual(dimension.size, 5)

    def testConvertToDimensionGenericInputs(self):
        dimension = mtf.convert_to_dimension(None)
        self.assertEqual(dimension, None)
        with self.assertRaises(TypeError):
            mtf.convert_to_dimension(5)

    @parameterized.parameters(
        (mtf.Shape([mtf.Dimension("x", 4),
                    mtf.Dimension("y", 8)]), ),
        ("x:4;y:8", ),
        ("x:4.y:8", ),
        ("x:4 y:8", ),
        ("x:4,y:8", ),
    )
    def testConvertToShape(self, inputs):
        shape = mtf.convert_to_shape(inputs)
        self.assertEqual(
            shape, mtf.Shape([mtf.Dimension("x", 4),
                              mtf.Dimension("y", 8)]))

    def testConvertToShapeGenericInputs(self):
        shape = mtf.convert_to_shape([])
        self.assertEqual(shape.dims, [])
        shape = mtf.convert_to_shape(None)
        self.assertEqual(shape, None)
        with self.assertRaises(ValueError):
            mtf.convert_to_shape("x;4")

    @parameterized.parameters(
        (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]), ),
        ("d_ff:model;heads:model", ),
        ("d_ff:model.heads:model", ),
        ("d_ff:model heads:model", ),
        ("d_ff:model,heads:model", ),
        ([("d_ff", "model"), ("heads", "model")], ),
    )
    def testConvertToLayoutRules(self, inputs):
        layout_rules = mtf.convert_to_layout_rules(inputs)
        self.assertEqual(
            layout_rules._pairs,
            mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)

    def testConvertToLayoutRulesGenericInputs(self):
        with self.assertRaises(ValueError):
            mtf.convert_to_layout_rules("d_ff;heads")

    def testTensorLayout(self):
        tensor_layout = mtf.TensorLayout([0, 2, 1])
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(0), ())
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(1), (0, ))
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(2), (0, 2))
        tensor_layout = mtf.TensorLayout([None, 0])
        self.assertFalse(tensor_layout.is_fully_replicated)
        tensor_layout = mtf.TensorLayout([None, None, None])
        self.assertTrue(tensor_layout.is_fully_replicated)

    def testGraph(self):
        graph = mtf.Graph()
        self.assertEmpty(graph.operations)
        self.assertEmpty(graph.trainable_variables)
        self.assertEmpty(graph.all_variables)
        mesh = mtf.Mesh(graph, "mesh_test")
        _ = mtf.import_tf_tensor(mesh,
                                 tf_tensor=tf.constant(0.),
                                 shape=mtf.Shape([]))
        self.assertLen(graph.operations, 1)
        self.assertEmpty(graph.trainable_variables)
        self.assertEmpty(graph.all_variables)
        _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
        self.assertLen(graph.operations, 2)
        self.assertLen(graph.trainable_variables, 1)
        self.assertLen(graph.all_variables, 1)
        _ = mtf.get_variable(mesh,
                             "variable_1",
                             mtf.Shape([]),
                             trainable=False)
        self.assertLen(graph.operations, 3)
        self.assertLen(graph.trainable_variables, 1)
        self.assertLen(graph.all_variables, 2)

    def testGraphNames(self):
        # Standard Usage.
        graph = mtf.Graph()
        self.assertEqual(graph.unique_name("a"), "a")
        self.assertEqual(graph.unique_name("a"), "a_1")
        self.assertEqual(graph.unique_name("a"), "a_2")

        # Edge cases, the user may choose the name "a_1".
        graph = mtf.Graph()
        self.assertEqual(graph.unique_name("a"), "a")
        self.assertEqual(graph.unique_name("a"), "a_1")
        self.assertEqual(graph.unique_name("a_1"), "a_1_1")

        graph = mtf.Graph()
        self.assertEqual(graph.unique_name("a"), "a")
        self.assertEqual(graph.unique_name("a_1"), "a_1")
        self.assertEqual(graph.unique_name("a"), "a_2")

        # Case insensitive.
        graph = mtf.Graph()
        self.assertEqual(graph.unique_name("a"), "a")
        self.assertEqual(graph.unique_name("A"), "A_1")

    @tf.contrib.eager.run_test_in_graph_and_eager_modes()
    def testLowering(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        inputs = tf.constant(0.)
        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          tf_tensor=inputs,
                                          shape=mtf.Shape([]))
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        outputs = lowering.export_to_tf_tensor(mtf_inputs)
        inputs_value, outputs_value = self.evaluate([inputs, outputs])
        self.assertEqual(inputs_value, outputs_value)

        # Check that methods run without error.
        _ = lowering.copy_masters_to_slices()
        _ = lowering.copy_slices_to_masters()

    def testMesh(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        self.assertEqual(mesh.graph, graph)

    def testMeshImpl(self):
        shape = mtf.Shape(
            [mtf.Dimension("batch", 4),
             mtf.Dimension("model", 8)])
        layout_rules = mtf.LayoutRules([("batch", "batch"), ("d_ff", "model"),
                                        ("heads", "model")])
        mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
        self.assertEqual(mesh_impl.shape, shape)
        self.assertLen(shape, mesh_impl.ndims)
        self.assertEqual(mesh_impl.layout_rules, layout_rules)
        self.assertEqual(mesh_impl.size, shape.size)
        self.assertTrue(mesh_impl.supports_control_dependencies)

        batch = mtf.Dimension("batch", 128)
        length = mtf.Dimension("length", 500)
        d_ff = mtf.Dimension("d_ff", 2048)
        heads = mtf.Dimension("heads", 8)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(batch), 0)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(d_ff), 1)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(heads), 1)
        self.assertEqual(
            mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
            mtf.TensorLayout([0, None, 1]))
Code Example #12
def recon_model(mesh,
                data,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """
    Prototype of function computing LPT deplacement.

    Returns output tensorflow and mesh tensorflow tensors
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        raise ValueError("Unsupported dtype: %s" % dtype)
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y,
                              nc // n_block_z):
        new_size = int(0.5 *
                       min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)

    klin = np.loadtxt('../data/Planck15_a1p00.txt').T[0]
    plin = np.loadtxt('../data/Planck15_a1p00.txt').T[1]
    ipklin = iuspline(klin, plin)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation

    if x0 is None:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.random_normal_initializer(
                                        mean=0.0, stddev=1, seed=None))
    else:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.constant_initializer(x0))
    print("\nfieldvar : \n", fieldvar)

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        # Here we can run our nbody
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        final_state = mtfpm.lpt_init_single(
            fieldvar,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #*nc**3

    # Total loss
    diff = (final_field - mtfdata)
    R0 = tf.constant(R0)
    print("R0 in the recon_model : ", R0)

    def _cwise_smooth(kfield, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    fields = [fieldvar, final_field]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
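
A hypothetical NumPy check of the Gaussian damping applied in _cwise_smooth above; the frequency grid here is only a stand-in for the kvec arrays produced by flowpm.kernels.fftk, and all sizes are assumed.

import numpy as np

nc, bs, R0 = 64, 100.0, 2.0                 # assumed grid size, box size, smoothing scale
kx = 2 * np.pi * np.fft.fftfreq(nc)         # stand-in for one axis of kvec
kk = (kx / bs * nc) ** 2                    # squared wavenumber, as in _cwise_smooth
wts = np.exp(-kk * (R0 * bs / nc) ** 2)     # Gaussian damping of each Fourier mode
print(wts[:4].round(4))                     # low-k modes pass through nearly unchanged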
Code Example #13
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    # MTF setup.
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info("device_list = %s" % device_list,)
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                mesh_devices,
                                                ctx.device_assignment)
    mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    batch_size = input_ids.get_shape()[0].value
    batch_dim = mtf.Dimension("batch", batch_size)
    seq_length = input_ids.get_shape()[1].value
    seq_dim = mtf.Dimension("seq", seq_length)

    mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
    mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                          [batch_dim, seq_dim])
    mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                           [batch_dim, seq_dim])

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    (start_logits, end_logits) = create_model(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=mtf_input_ids,
        input_mask=mtf_input_mask,
        segment_ids=mtf_segment_ids)

    if mode == tf.estimator.ModeKeys.TRAIN:

      def compute_loss(logits, positions):
        one_hot_positions = mtf.one_hot(positions, output_dim=seq_dim)
        log_probs = mtf.log_softmax(logits, seq_dim)
        loss = -mtf.reduce_mean(
            mtf.reduce_sum(one_hot_positions * log_probs, reduced_dim=seq_dim))
        return loss

      start_positions = features["start_positions"]
      mtf_start_positions = mtf.import_tf_tensor(mesh, start_positions,
                                                 [batch_dim])
      end_positions = features["end_positions"]
      mtf_end_positions = mtf.import_tf_tensor(mesh, end_positions, [batch_dim])

      start_loss = compute_loss(start_logits, mtf_start_positions)
      end_loss = compute_loss(end_logits, mtf_end_positions)

      total_loss = (start_loss + end_loss) / 2.0
      _, update_ops = optimization_lib.create_optimizer(
          total_loss,
          learning_rate,
          num_train_steps,
          num_warmup_steps,
          max_optimized_variable_size=FLAGS.max_optimized_variable_size,
          optimizer=FLAGS.optimizer,
          clip_gradients=FLAGS.clip_gradients)
    elif mode == tf.estimator.ModeKeys.PREDICT:
      start_logits = mtf.anonymize(start_logits)
      end_logits = mtf.anonymize(end_logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    if mode == tf.estimator.ModeKeys.TRAIN:
      tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))
      global_step = tf.train.get_global_step()
      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
      train_op = tf.group(tf_update_ops)

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = bert_lib.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      if mode == tf.estimator.ModeKeys.TRAIN:
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            FLAGS.output_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

        return tf.estimator.tpu.TPUEstimatorSpec(
            mode,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook],
            scaffold_fn=scaffold_fn)
      elif mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "unique_ids": unique_ids,
            "start_logits": lowering.export_to_tf_tensor(start_logits),
            "end_logits": lowering.export_to_tf_tensor(end_logits),
        }

        return tf.estimator.tpu.TPUEstimatorSpec(
            mode,
            prediction_hooks=[restore_hook],
            predictions=predictions,
            scaffold_fn=scaffold_fn)
      else:
        raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                         (mode))
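
compute_loss above scores the gold start/end position with a log-softmax over sequence positions. A tiny NumPy analogue for a single example, with assumed values:

import numpy as np

start_logits = np.array([1.0, 0.2, -0.5, 0.7])   # per-position start logits (assumed)
gold_position = 0                                 # gold start index (assumed)
log_probs = start_logits - np.log(np.exp(start_logits).sum())
loss = -log_probs[gold_position]                  # negative log-likelihood of the gold position
print(round(loss, 4))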
Code Example #14
def synthetic_attention(q,
                        k,
                        v,
                        memory_length_dim,
                        key_dim,
                        value_dim,
                        bias=None,
                        dropout_rate=0.0,
                        dropout_broadcast_dims=None,
                        extra_logit=None,
                        synthesize=True,
                        synthesize_mode="random_plus_alpha",
                        factorized_dim=16,
                        max_length=512,
                        context=None):
  """Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743).

  key_dim is a Dimension representing the channels in the queries and keys
  value_dim is a Dimension representing the channels in values
  memory_length_dim is a Dimension representing the different key/value pairs.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
  other_memory_dims is a subset of other_query_dims

  Typically, other_query_dims={batch, heads, length}
  Typically, other_memory_dims={batch, heads}

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: a Tensor to be added into the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor
    synthesize: flag to use synthetic attention or not
    synthesize_mode: which variant of synthesizer to use
    factorized_dim: factorized dim for synthesizers
    max_length: max length of input sequence
    context: a Context object (needed for the mesh, variable dtype and mode)

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """

  if synthesize:
    num_heads = v.shape.get_dim_by_name("heads")
    tf.logging.info("Using synthesizer")
    if synthesize_mode == "random":
      tf.logging.info("Using Random Synthesizers")
      r_shape = mtf.Shape([mtf.Dimension("length", max_length),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", max_length)])
      initializer = tf.random_uniform_initializer()
      r = mtf.get_variable(context.mesh, "R", r_shape,
                           initializer=None,
                           dtype=context.variable_dtype)
      r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
      if context.mode == "incremental":
        r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
      else:
        length_dim = q.shape.get_dim_by_name("length")
        r = mtf.slice(r, 0, length_dim.size, "length")
      logits = r
      r_shape = logits.shape
    elif synthesize_mode == "factorized":
      tf.logging.info("Using Factorized Random Synthesizers")
      k = factorized_dim
      r1_shape = mtf.Shape([mtf.Dimension("tmp", k),
                            mtf.Dimension("heads", num_heads.size),
                            mtf.Dimension("memory_length", 512)])
      r2_shape = mtf.Shape([mtf.Dimension("tmp", k),
                            mtf.Dimension("heads", num_heads.size),
                            mtf.Dimension("memory_length", 512)])
      r_shape = mtf.Shape([mtf.Dimension("length", 512),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", 512)])
      initializer = tf.random_normal_initializer()
      r1 = mtf.get_variable(context.mesh, "R1", r1_shape,
                            initializer=initializer,
                            dtype=context.variable_dtype)
      r2 = mtf.get_variable(context.mesh, "R2", r2_shape,
                            initializer=initializer,
                            dtype=context.variable_dtype)
      r = mtf.einsum([r1, r2], r_shape)
      r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
      if context.mode == "incremental":
        r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
      else:
        length_dim = q.shape.get_dim_by_name("length")
        r = mtf.slice(r, 0, length_dim.size, "length")
      logits = r
    elif synthesize_mode == "dense_minus":
      # Dense Synthesizer Model
      tmp_dim = mtf.Dimension("memory_length", max_length)
      logits = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                use_bias=False,
                                name="pi",
                                reduced_dims=[key_dim],
                                variable_dtype=None)
      logits = mtf.slice(logits, 0, memory_length_dim.size,
                         memory_length_dim.name)
      if context.mode == "incremental":
        pass
      else:
        length_dim = q.shape.get_dim_by_name("length")
        logits = mtf.slice(logits, 0, length_dim.size, "length")
    elif synthesize_mode == "random_plus_alpha":
      # Mixture Random Synthesizer with learnable Alpha
      tf.logging.info("Using Random Plus Alpha")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      num_heads = logits.shape.get_dim_by_name("heads")
      r_shape = mtf.Shape([mtf.Dimension("length", 512),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", 512)])
      r = mtf.get_variable(context.mesh, "R", r_shape,
                           initializer=None,
                           dtype=context.variable_dtype)
      r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
      if context.mode == "incremental":
        r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
      else:
        length_dim = q.shape.get_dim_by_name("length")
        r = mtf.slice(r, 0, length_dim.size, length_dim.name)
      alpha = mtf.get_variable(context.mesh,
                               "alpha",
                               mtf.Shape([mtf.Dimension("alpha", 1)]),
                               initializer=tf.zeros_initializer(),
                               dtype=context.variable_dtype)
      alpha = mtf.sigmoid(alpha)
      logits = ((1-alpha) * logits) + (alpha * r)
    elif synthesize_mode == "dense_plus_alpha":
      # Mixture Dense Synthesizer with learnable alpha
      tf.logging.info("Using Dense Plus Alpha Scaling")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      tmp_dim = mtf.Dimension("memory_length", 512)
      r = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                           use_bias=False,
                           name="pi",
                           reduced_dims=[key_dim],
                           variable_dtype=None)
      r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
      if context.mode == "incremental":
        pass
      else:
        length_dim = q.shape.get_dim_by_name("length")
        r = mtf.slice(r, 0, length_dim.size, "length")
      alpha = mtf.get_variable(context.mesh,
                               "alpha",
                               mtf.Shape([mtf.Dimension("alpha", 1)]),
                               initializer=tf.zeros_initializer(),
                               dtype=context.variable_dtype)
      alpha = mtf.sigmoid(alpha)
      logits = ((1-alpha) * logits) + (alpha * r)
  if bias is not None:
    logits += bias

  weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
  if dropout_rate != 0.0:
    weights = mtf.dropout(
        weights, 1.0 - dropout_rate,
        noise_shape=weights.shape - dropout_broadcast_dims)

  if synthesize and "plus" not in synthesize_mode:
    if synthesize_mode == "dense_minus":
      outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim])
    else:
      outputs_shape = mtf.Shape(q.shape.dims[:-1] + [num_heads, value_dim])
  else:
    outputs_shape = q.shape - [key_dim] + value_dim
  outputs = mtf.einsum([weights, v], outputs_shape)
  return outputs
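As a hedged illustration of the dimension conventions in the docstring above (other_query_dims = {batch, heads, length}), the following self-contained snippet builds plain dot-product logits and attention weights with mtf.einsum and mtf.softmax on a single-device placement mesh; all names and sizes are made up for the example:

import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
batch = mtf.Dimension("batch", 2)
heads = mtf.Dimension("heads", 4)
length = mtf.Dimension("length", 8)
memory_length = mtf.Dimension("memory_length", 8)
key_dim = mtf.Dimension("d_k", 16)
value_dim = mtf.Dimension("d_v", 16)

q = mtf.import_tf_tensor(mesh, tf.random_normal([2, 4, 8, 16]),
                         shape=mtf.Shape([batch, heads, length, key_dim]))
k = mtf.import_tf_tensor(mesh, tf.random_normal([2, 4, 8, 16]),
                         shape=mtf.Shape([batch, heads, memory_length, key_dim]))
v = mtf.import_tf_tensor(mesh, tf.random_normal([2, 4, 8, 16]),
                         shape=mtf.Shape([batch, heads, memory_length, value_dim]))

# logits: q.shape - key_dim + memory_length; weights normalized over memory_length
logits = mtf.einsum([q, k], reduced_dims=[key_dim])
weights = mtf.softmax(logits, memory_length)
# outputs: q.shape - key_dim + value_dim, matching the contract in the docstring
outputs = mtf.einsum([weights, v], q.shape - key_dim + value_dim)

mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
with tf.Session() as sess:
  print(sess.run(lowering.export_to_tf_tensor(outputs)).shape)  # (2, 4, 8, 16)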
Code Example #15
0
File: cifar10.py Project: mkrdip/alcf
def cifar_model(features, labels, mesh):
    """The model.

  Args:
    features: a dict of tf.Tensors with keys 'image' (32*32*3 values per
      example) and 'label' (shape [batch], dtype tf.int32)
    labels: unused; the label is read from features['label'] below
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
    features = copy.copy(features)
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 8)
    cols_dim = mtf.Dimension("cols_size", 8)

    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 3)

    image = features['image']
    image = bnorm(image)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 4, 8, 4, 8, 3]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
            one_channel_dim
        ]))
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])

    # Add some convolutional layers to demonstrate that convolution works.
    fh_dim = mtf.Dimension("fh", 7)
    fw_dim = mtf.Dimension("fw", 7)
    filters1_dim = mtf.Dimension("filters1", 32)
    filters2_dim = mtf.Dimension("filters2", 32)

    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [fh_dim, fw_dim, filters1_dim, filters2_dim])

    f1 = mtf.relu(
        mtf.conv2d_with_blocks(x,
                               kernel1,
                               strides=[1, 1, 1, 1],
                               padding="SAME",
                               h_blocks_dim=row_blocks_dim,
                               w_blocks_dim=col_blocks_dim))

    f2 = mtf.relu(
        mtf.conv2d_with_blocks(f1,
                               kernel2,
                               strides=[1, 1, 1, 1],
                               padding="SAME",
                               h_blocks_dim=row_blocks_dim,
                               w_blocks_dim=col_blocks_dim))

    x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)

    # Add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

    h1 = mtf.layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-4:],
                          activation=mtf.relu,
                          name="hidden1")
    h2 = mtf.layers.dense(h1, hidden_dim2, activation=mtf.relu, name="hidden2")
    logits = mtf.layers.dense(h2, classes_dim, name="logits")

    labels = features['label']
    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)

    return logits, loss
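To train a model function like the one above, the mtf graph is lowered to TensorFlow and the gradients are applied through an mtf optimizer. A minimal self-contained sketch of that flow on a single device, using a toy dense model in place of cifar_model; mtf.optimize.SgdOptimizer and the sizes used here are assumptions for illustration:

import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
batch_dim = mtf.Dimension("batch", 8)
io_dim = mtf.Dimension("io", 4)
classes_dim = mtf.Dimension("classes", 3)

x = mtf.import_tf_tensor(mesh, tf.random_normal([8, 4]),
                         shape=mtf.Shape([batch_dim, io_dim]))
labels = mtf.import_tf_tensor(mesh, tf.zeros([8], dtype=tf.int32),
                              shape=mtf.Shape([batch_dim]))

logits = mtf.layers.dense(x, classes_dim, name="logits")
loss = mtf.reduce_mean(
    mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(labels, classes_dim), classes_dim))

# Differentiate with respect to the trainable variables and build update ops.
var_grads = mtf.gradients(
    [loss], [v.outputs[0] for v in graph.trainable_variables])
optimizer = mtf.optimize.SgdOptimizer(0.01)
update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_loss = lowering.export_to_tf_tensor(loss)
train_op = tf.group([lowering.lowered_operation(op) for op in update_ops])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(lowering.copy_masters_to_slices())  # copy master variables to slices first
  print(sess.run([tf_loss, train_op])[0])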
Code Example #16
0
    def testPool(self, pooling_method):
        batch = 2
        depth = 3
        height = 4
        width = 6
        channels = 3
        tf.random.set_random_seed(1234)
        inputs = tf.random_normal([batch, depth, height, width, channels])

        stride_d = 3
        stride_h = 2
        stride_w = 3

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        depth_dim = mtf.Dimension("depth", depth)
        height_dim = mtf.Dimension("height", height)
        width_dim = mtf.Dimension("width", width)
        channels_dim = mtf.Dimension("channels", channels)

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape([
                                              batch_dim, depth_dim, height_dim,
                                              width_dim, channels_dim
                                          ]))

        if pooling_method == "MAX_2D":
            mtf_outputs = mtf.layers.max_pool2d(mtf_inputs,
                                                ksize=(stride_h, stride_w))
            inputs = tf.reshape(inputs,
                                [batch * depth, height, width, channels])
            expected_outputs = tf.keras.layers.MaxPooling2D(
                (stride_h, stride_w))(inputs)
            expected_outputs = tf.reshape(expected_outputs, [
                batch, depth,
                int(height / stride_h),
                int(width / stride_w), channels
            ])

        elif pooling_method == "AVG_2D":
            mtf_outputs = mtf.layers.avg_pool2d(mtf_inputs,
                                                ksize=(stride_h, stride_w))
            inputs = tf.reshape(inputs,
                                [batch * depth, height, width, channels])
            expected_outputs = tf.keras.layers.AveragePooling2D(
                (stride_h, stride_w))(inputs)
            expected_outputs = tf.reshape(expected_outputs, [
                batch, depth,
                int(height / stride_h),
                int(width / stride_w), channels
            ])

        elif pooling_method == "MAX_3D":
            mtf_outputs = mtf.layers.max_pool3d(
                mtf_inputs, ksize=[stride_d, stride_h, stride_w])
            expected_outputs = tf.keras.layers.MaxPooling3D(
                [stride_d, stride_h, stride_w])(inputs)

        elif pooling_method == "AVG_3D":
            mtf_outputs = mtf.layers.avg_pool3d(
                mtf_inputs, ksize=[stride_d, stride_h, stride_w])
            expected_outputs = tf.keras.layers.AveragePooling3D(
                [stride_d, stride_h, stride_w])(inputs)

        mtf_gradient = mtf.gradients([mtf_outputs], [mtf_inputs])[0]

        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
        actual_gradient = lowering.export_to_tf_tensor(mtf_gradient)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual, expected = self.evaluate([actual_outputs, expected_outputs])
        self.assertAllClose(actual, expected)

        actual = self.evaluate(actual_gradient)
        if pooling_method == "MAX_2D":
            expected_non_zeros = batch * depth * height * width * channels / (
                stride_h * stride_w)
            self.assertEqual(np.count_nonzero(actual), expected_non_zeros)

        elif pooling_method == "AVG_2D":
            expected = np.ones((batch, depth, height, width, channels),
                               dtype=np.float32) / stride_h / stride_w
            self.assertAllClose(actual, expected)

        elif pooling_method == "MAX_3D":
            expected_non_zeros = batch * depth * height * width * channels / (
                stride_d * stride_h * stride_w)
            self.assertEqual(np.count_nonzero(actual), expected_non_zeros)

        elif pooling_method == "AVG_3D":
            expected = np.ones(
                (batch, depth, height, width, channels),
                dtype=np.float32) / stride_d / stride_h / stride_w
            self.assertAllClose(actual, expected)
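The test method above receives pooling_method as an argument, so it is presumably driven by absl's parameterized test decorators; a hedged sketch of how such a parameterization could look (the class name and parameter list are illustrative, not the original test file):

from absl.testing import parameterized
import tensorflow.compat.v1 as tf

class PoolTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters("MAX_2D", "AVG_2D", "MAX_3D", "AVG_3D")
  def testPool(self, pooling_method):
    ...  # body as in the example above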
Code Example #17
0
File: input_reader_test.py Project: yaohuaxin/mesh
    def test_get_laidout_tensors(self, is_eval_mode):
        mesh_shape = "mesh_x:2, mesh_y:1"
        layout = "batch:mesh_x, io:mesh_y"
        batch_io_dim = 4

        with tf.Session() as sess:
            topology, num_cores = self.initialize_system(sess)

            # Get a device_assignment object for mtf.
            d_assignment = device_assignment.device_assignment(
                topology, computation_shape=[1, 1, 1], num_replicas=num_cores)

            # Hacked dataset creator: creates different datasets for the first and
            # second call, in order to test SimdMeshImplInputReader.
            self.sub_batch_created_times = 0

            def stateful_ds_creator():
                whole_batch = tf.eye(batch_io_dim, dtype=tf.float32)
                sub_batch = tf.slice(whole_batch,
                                     [self.sub_batch_created_times * 2, 0],
                                     [2, 4])
                self.sub_batch_created_times += 1
                return tf.data.Dataset.from_tensors(
                    sub_batch).repeat().unbatch()

            batch_dim = mtf.Dimension("batch", batch_io_dim)
            io_dim = mtf.Dimension("io", batch_io_dim)
            mtf_input_shapes = [mtf.Shape([batch_dim, io_dim])]

            # Get mesh_impl.
            mesh_shape = mtf.convert_to_shape(mesh_shape)
            layout_rules = mtf.convert_to_layout_rules(layout)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, None, d_assignment)

            simd_input_reader = input_reader.SimdMeshImplInputReader(
                mesh_impl,
                stateful_ds_creator,
                mtf_input_shapes,
                external_worker=False,
                is_eval_mode=is_eval_mode)

            def model_fn(features):
                return features

            replicated_computation = tpu.replicate(
                computation=model_fn,
                inputs=[[]] * num_cores,
                infeed_queue=simd_input_reader.infeed_queue,
                device_assignment=d_assignment)

            simd_input_reader.start_infeed_thread(sess, 1)
            results = sess.run(replicated_computation)
            print("results: {}".format(results))

            core_0_data = results[0][0]
            core_1_data = results[1][0]
            print("core_0_data: {}".format(core_0_data))
            print("core_1_data: {}".format(core_1_data))

            if is_eval_mode:
                # If there is only one dataset object, then the stateful_ds_creator()
                # should be called only once.
                self.assertAllClose(
                    np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
                    core_0_data)
                self.assertAllClose(
                    np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
                    core_1_data)
            else:
                # If there are two dataset objects, then the stateful_ds_creator()
                # should be called twice.
                self.assertAllClose(
                    np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
                    core_0_data)
                self.assertAllClose(
                    np.array([[0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32),
                    core_1_data)

            sess.run(tf.tpu.shutdown_system())
Code Example #18
0
File: attention.py Project: manoelakohler/mesh
def _maybe_reshape_attention_input_for_2d_sharding(context, q, k, v, bias,
                                                   unsplittable_dims):
    """Reshape the inputs to attention to split over an unused mesh dimension.

  In the case where the attention computation is unnecessarily replicated,
  this function reshapes the attention inputs to remove the unnecessary
  replication.

  This becomes relevant when doing 2-dimensional model parallelism.
  d_model is sharded over one mesh dimension and [vocab, num_heads, d_ff] are
  sharded over the other mesh dimension.  This fully distributes all of the
  einsum operations, except for the internals of the attention computation.

  To distribute that computation, this function creates a new tensor-dimension
  from the low bits of either the batch dimension or the num_heads dimension,
  and then splits that dimension over the unused mesh dimension.

  Args:
    context: a transformer.Context
    q: a Tensor
    k: a Tensor
    v: a Tensor
    bias: a Tensor
    unsplittable_dims: a list of tensor-dimensions not to split.  The key/value
      dimensions should be passed here.
  Returns:
    reshaped_q: a Tensor
    reshaped_k: a Tensor
    reshaped_v: a Tensor
    reshaped_bias: a Tensor
  """
    original_inputs = q, k, v, bias
    # we need to know the layout and mesh-shape to figure out what to do.
    if not context or not context.model.layout or not context.model.mesh_shape:
        return original_inputs
    mesh_shape = mtf.convert_to_shape(context.model.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(context.model.layout)
    # find a mesh dim that is unused (no tensor-dimension is split across it)
    mesh_axis_used = [False] * mesh_shape.ndims
    for x in original_inputs:
        for mesh_axis in layout_rules.tensor_layout(
                x.shape, mesh_shape).tensor_axis_to_mesh_axis:
            if mesh_axis is not None:
                mesh_axis_used[mesh_axis] = True
    if False not in mesh_axis_used:
        return original_inputs
    mesh_dim = mesh_shape.dims[mesh_axis_used.index(False)]
    # Choose an appropriate name for the new tensor-dimension so that the layout
    #   will know to split it across the unused mesh dimension.
    tensor_dim_name = None
    tensor_dim_name = layout_rules.mesh_dimension_name_to_tensor_dimension_names(
        mesh_dim.name)
    if tensor_dim_name:
        tensor_dim_name = tensor_dim_name[0]
    else:
        return original_inputs
    # Find a tensor-dimension that we can further split, by breaking off the
    # lower bits into our new tensor-dimension.
    # This resplittable tensor-dimension must be present in all of q, k, v
    #   and must be large enough to be further split.
    resplittable_dim = None
    for d in q.shape.dims:
        if d in k.shape.dims and d in v.shape.dims and d not in unsplittable_dims:
            num_splits = mtf.tensor_dim_to_mesh_dim_size(
                context.model.layout, context.model.mesh_shape, d)
            if d.size % (num_splits * mesh_dim.size) == 0:
                resplittable_dim = d
                break
    if not resplittable_dim:
        return original_inputs
    new_dim_high = mtf.Dimension(resplittable_dim.name, num_splits)
    new_dim_low = mtf.Dimension(tensor_dim_name,
                                resplittable_dim.size // num_splits)

    def _my_reshape(x):
        if x and resplittable_dim in x.shape.dims:
            return mtf.replace_dimensions(x, resplittable_dim,
                                          [new_dim_high, new_dim_low])
        else:
            return x

    return _my_reshape(q), _my_reshape(k), _my_reshape(v), _my_reshape(bias)
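The key operation in the helper above is mtf.replace_dimensions, which splits one tensor-dimension into two so that the new low-order dimension can be picked up by an otherwise unused mesh axis through the layout rules. A hedged, single-process sketch of just that reshape (dimension names and sizes are made up):

import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
batch = mtf.Dimension("batch", 8)
length = mtf.Dimension("length", 4)

x = mtf.import_tf_tensor(mesh, tf.random_normal([8, 4]),
                         shape=mtf.Shape([batch, length]))

# Split "batch" (8) into "batch" (4) x "batch_low" (2); a layout rule mapping
# "batch_low" to the unused mesh dimension would then distribute the computation.
new_high = mtf.Dimension("batch", 4)
new_low = mtf.Dimension("batch_low", 2)
y = mtf.replace_dimensions(x, batch, [new_high, new_low])
print(y.shape)  # new shape: [batch=4, batch_low=2, length=4]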
Code Example #19
0
 def vocab_dim(self):
     # pad vocab to a multiple of 128 so as to be splittable.
     # TODO(noam): This creates issues in checkpoint compatibility
     n = self.config.vocab_size
     return mtf.Dimension("vocab", n + (-n % 128))
Code Example #20
0
File: attention.py Project: manoelakohler/mesh
def hybrid_attention(q,
                     k,
                     v,
                     context,
                     memory_length_dim,
                     key_dim,
                     value_dim,
                     bias=None,
                     dropout_rate=0.0,
                     dropout_broadcast_dims=None,
                     extra_logit=None):
    """Dot-product attention - doesn't use positional dimensions.

  key_dim is a Dimension representing the channels in the queries and keys
  value_dim is a Dimension representing the channels in values
  memory_length_dim is a Dimension representing the different key/value pairs.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
  other_memory_dims is a subset of other_query_dims

  Typically, other_query_dims={batch, heads, length}
  Typically, other_memory_dims={batch, heads}

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    context: context of the attention layer.
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: a Tensor to be added into the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """
    logits = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
    if bias is not None:
        logits += bias

    query_length_dim = mtf.Dimension("length", memory_length_dim.size)
    doubly_coeff = mtf.get_variable(context.mesh,
                                    "doubly_coeff", [],
                                    initializer=tf.constant_initializer(0.5),
                                    dtype=context.variable_dtype)
    doubly_coeff = mtf.maximum(mtf.minimum(doubly_coeff, 1.), 0.)

    upper_weights = mtf.softmax(logits,
                                memory_length_dim,
                                extra_logit=extra_logit)

    lower_log_weights = mtf.log_softmax(logits,
                                        query_length_dim,
                                        extra_logit=extra_logit)
    doubly_weights = mtf.softmax(lower_log_weights,
                                 memory_length_dim,
                                 extra_logit=extra_logit)

    weights = doubly_coeff * doubly_weights + (1. -
                                               doubly_coeff) * upper_weights
    if dropout_rate != 0.0:
        weights = mtf.dropout(weights,
                              1.0 - dropout_rate,
                              noise_shape=weights.shape -
                              dropout_broadcast_dims)
    outputs_shape = q.shape - key_dim + value_dim
    outputs = mtf.einsum([weights, v], outputs_shape)
    return outputs
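hybrid_attention mixes the usual softmax over the memory axis with a doubly-normalized variant (log-softmax over the query axis followed by softmax over the memory axis), gated by the learned scalar doubly_coeff. A hedged NumPy sketch of just that weight computation, outside of mtf:

import numpy as np

def _softmax(x, axis):
  x = x - x.max(axis=axis, keepdims=True)
  e = np.exp(x)
  return e / e.sum(axis=axis, keepdims=True)

def _log_softmax(x, axis):
  x = x - x.max(axis=axis, keepdims=True)
  return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 4, 8, 8))  # [batch, heads, length, memory_length]
doubly_coeff = 0.5                      # stands in for the learned, clipped scalar

upper_weights = _softmax(logits, axis=-1)                          # over memory_length
doubly_weights = _softmax(_log_softmax(logits, axis=-2), axis=-1)  # over length, then memory_length
weights = doubly_coeff * doubly_weights + (1. - doubly_coeff) * upper_weights
print(weights.sum(axis=-1))  # every row still sums to 1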
Code Example #21
0
 def token_type_vocab_dim(self):
     return mtf.Dimension("token_type_vocab", self.config.type_vocab_size)
Code Example #22
0
def gradient_based_subword_tokenization(x,
                                        length_dim,
                                        max_subword_length=4,
                                        downsample=None,
                                        use_offsets=False,
                                        consider_chars_as_blocks=False,
                                        use_block_pos_embedding=False,
                                        share_block_kernel=False,
                                        memory_embeddings=0,
                                        context=None,
                                        block_mixing_mode=None,
                                        activation="softmax",
                                        downsample_function="mean"):
  """Implements GBSWT from Charformer.

  Args:
    x: a Tensor containing length_dim
    length_dim: a Dimension
    max_subword_length: integer
    downsample: integer.
    use_offsets: boolean.
    consider_chars_as_blocks: boolean.
    use_block_pos_embedding: boolean.
    share_block_kernel: boolean.
    memory_embeddings: integer.
    context: Context.
    block_mixing_mode: Str for block mixing.
    activation: Str for block ranking.
    downsample_function: Str, supports mean/linformer for now.

  Returns:
    a Tensor with the same shape as x.

  Raises:
    ValueError: if channels or depth don't match.
  """
  # don't use this for now.
  del max_subword_length
  del memory_embeddings
  all_blocks = []
  all_scores = []
  tf.logging.info("GSW block layer")

  def _tile(x, n, tile_dim):
    # Simple tile function in MTF.
    return mtf.concat([x] * n, tile_dim.name)

  def _repeat(x, n, repeat_dim):
    # repeat function in MTF
    tmp_dim = mtf.Dimension("tmp", 1)
    expand_shape = mtf.Shape(x.shape.dims + [tmp_dim])
    x = mtf.reshape(x, expand_shape)
    x = _tile(x, n, tmp_dim)
    output_shape = []
    for dim in x.shape.dims:
      if dim.name == "tmp":
        continue
      if dim.name == repeat_dim.name:
        dim = mtf.Dimension(dim.name, dim.size * n)
      output_shape.append(dim)
    output_shape = mtf.Shape(output_shape)
    x = mtf.reshape(x, output_shape)
    return x

  def _combined_dim(dims):
    return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)

  # compute all subword blocks
  # TODO(yitay): handle offsets to get all blocks
  if activation == "sigtanh":
    # one score for sigmoid
    tmp_dim = mtf.Dimension("block_score", 2)
  else:
    tmp_dim = mtf.Dimension("block_score", 1)

  model_dim = x.shape[-1]
  subword_blocks_width = [2, 3, 4]

  if consider_chars_as_blocks:
    subword_blocks_width += [1]

  if share_block_kernel:
    block_kernel_shape = mtf.Shape([model_dim, tmp_dim])
    block_kernel = mtf.get_variable(
        x.mesh, "block_kernel", block_kernel_shape, initializer=None,
        dtype=context.variable_dtype)
  else:
    block_kernel = None

  for subword_len in subword_blocks_width:
    if use_block_pos_embedding:
      # this is turned off by default. It is meant to support cases like
      # parameterized pooling or other features.
      block_len_dim = mtf.Dimension(length_dim.name, subword_len)
      # TODO(vqtran): Consider other positional embeddings.
      block_pos_emb = sinusoid_positional_embedding_weights(
          context.mesh, block_len_dim, x.shape[-1],
          context.variable_dtype.activation_dtype)
      block_pos_emb = _repeat(block_pos_emb,
                              math.ceil(length_dim.size / float(subword_len)),
                              block_len_dim)
    if use_offsets:
      offset_space = subword_len
    else:
      offset_space = 1
    for offsets in range(offset_space):
      if offsets > 0:
        xoff = mtf.shift(x, offsets, length_dim, wrap=False)
        if use_block_pos_embedding:
          block_pos_emb = mtf.shift(
              block_pos_emb, offsets, block_pos_emb.shape[-2], wrap=False)
      else:
        xoff = x
      tf.logging.info("SW len=%d offset=%d", subword_len, offsets)
      if length_dim.size % subword_len != 0:
        tf.logging.info("Not divisible by length")
        # add extra padding tokens
        pad_amt = int(subword_len) - int(
            length_dim.size % subword_len)
        kp = mtf.pad(xoff, [0, pad_amt], length_dim.name)
      else:
        kp = xoff

      if use_block_pos_embedding:
        kp += block_pos_emb

      bx = mtf.pool_tensor_1d(
          kp,
          pool_dim=kp.shape.get_dim_by_name("length"),
          reduce_fn=mtf.reduce_mean,
          pool_size=int(subword_len))
      block_score = mtf.layers.dense(
          bx, [tmp_dim],
          use_bias=False,
          name="bx",
          reduced_dims=[model_dim],
          variable_dtype=None,
          kernel_weights=block_kernel)

      expand_bx = _repeat(bx, subword_len, length_dim)
      expand_scores = _repeat(block_score, subword_len, length_dim)
      if offsets > 0:
        # add offset.
        expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name)
        expand_scores = mtf.pad(expand_scores, [offsets, 0], length_dim.name)
      new_len = expand_bx.shape.get_dim_by_name(length_dim.name)
      if new_len.size < length_dim.size:
        pad_amt = length_dim.size - new_len.size
        expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name)
        expand_scores = mtf.pad(expand_scores, [0, pad_amt], length_dim.name)
      elif new_len.size > length_dim.size:
        expand_bx = mtf.slice(expand_bx, 0, length_dim.size, length_dim.name)
        expand_scores = mtf.slice(expand_scores, 0, length_dim.size,
                                  length_dim.name)

      new_tmp_dim = mtf.Dimension("extra_dim", 1)
      expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim])
      expand_scores_shape = mtf.Shape(expand_scores.shape.dims + [new_tmp_dim])
      expand_bx = mtf.reshape(expand_bx, expand_shape)
      expand_scores = mtf.reshape(expand_scores, expand_scores_shape)
      all_blocks.append(expand_bx)
      all_scores.append(expand_scores)

  all_blocks = mtf.concat(all_blocks, new_tmp_dim.name)
  all_scores = mtf.concat(all_scores, new_tmp_dim.name)
  tf.logging.info(all_blocks)
  new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim")
  combined_dim = _combined_dim([new_tmp_dim, tmp_dim])
  block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim
  block_net = mtf.reshape(all_scores, block_net_shape)

  if block_mixing_mode == "score_attention":
    tf.logging.info("Using score attention")
    att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim])
    tf.logging.info(block_net)
    att = mtf.softmax(att, reduced_dim=att.shape[-1])
    block_net = mtf.einsum([att, block_net], output_shape=block_net.shape)
    tf.logging.info(block_net)

  if activation == "softmax":
    block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim)
  elif activation == "tanh":
    tf.logging.info("Using tanh")
    block_net = mtf.tanh(block_net)

  all_blocks = block_net * all_blocks
  all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim)
  output = all_blocks

  if downsample:
    output_length = output.shape.get_dim_by_name("length")
    if output_length.size % int(downsample) != 0:
      pad_amt = int(downsample) - int(output_length.size % int(downsample))
      output = mtf.pad(output, [0, pad_amt], output_length.name)
    if downsample_function == "mean":
      output = mtf.pool_tensor_1d(
          output,
          pool_dim=output.shape.get_dim_by_name("length"),
          reduce_fn=mtf.reduce_mean,
          pool_size=int(downsample))
    else:
      raise ValueError("Downsampling function not implemeneted.")

  return output
Code Example #23
0
 def num_heads_dim(self):
     return mtf.Dimension("num_heads", self.config.attention_num_heads)
Code Example #24
0
 def _combined_dim(dims):
   return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)
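_combined_dim fuses a list of dimensions into a single dimension that keeps the first dimension's name and carries the product of the sizes (mtf.Shape(dims).size is that product). A tiny hedged check:

import mesh_tensorflow as mtf

def _combined_dim(dims):
  return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)

d = _combined_dim([mtf.Dimension("extra_dim", 3), mtf.Dimension("block_score", 2)])
print(d)  # Dimension(name='extra_dim', size=6)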
Code Example #25
0
 def key_dim(self):
     """Dimensionality of attention key."""
     if self.config.attention_key_head_size is None:
         raise ValueError("The key head size is not defined.")
     return mtf.Dimension("d_k", self.config.attention_key_head_size)
Code Example #26
0
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None,
                       context=None):
  """Attention to the a neighborood around the source.

  If fully_autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not fully_autoregressive, then query position p can only see memory
  positions in the range (p - radius, p + radius].

  In addition, if write_priority and read_priority are provided, then attention
  is limited to position pairs where
  read_priority[query position] >= write_priority[memory position]

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If none then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    fully_autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer
    write_priority: an optional Tensor containing length_dim
    read_priority: an optional Tensor containing length_dim
    attention_kwargs: optional keyword arguments for attention()
    context: optional context.

  Returns:
    a Tensor with the shape x.shape - key_dim + value_dim

  Raises:
    ValueError: if channels or depth don't match.
  """
  # Choose a suitable block size.
  # We choose the greatest divisor of length_per_split less than or equal
  # to max(radius, 128)
  tf.logging.info(attention_kwargs)
  length_per_split = length_dim.size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  query_block_length = mtf.Dimension("query_block_length", block_length)
  memory_block_length = mtf.Dimension("memory_block_length", block_length)
  # The num_blocks dimension gets the same name as the length dimension,
  # so it will be split in the same way.
  num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)
  def _reshape_query(x):
    return mtf.replace_dimensions(
        x, length_dim, [num_blocks, query_block_length])
  def _reshape_memory(x):
    x = mtf.replace_dimensions(
        x, length_dim, [num_blocks, memory_block_length])
    return (mtf.left_halo_exchange if fully_autoregressive
            else mtf.halo_exchange)(
                x, num_blocks, memory_block_length, radius)
  q = _reshape_query(q)
  k = _reshape_memory(k)
  if v:
    v = _reshape_memory(v)
  else:
    v = k
  if sequence_id is None:
    sequence_id = 1
  if (not isinstance(sequence_id, mtf.Tensor) or
      length_dim not in sequence_id.shape.dims):
    sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
  q_sequence_id = _reshape_query(sequence_id)
  m_sequence_id = _reshape_memory(sequence_id)
  pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
  q_pos = _reshape_query(pos)
  m_pos = _reshape_memory(pos)

  padded_memory_block_length = mtf.Dimension(
      "memory_block_length",
      (1 if fully_autoregressive else 2) * radius + block_length)

  relative_position = m_pos - q_pos
  visible = mtf.equal(q_sequence_id, m_sequence_id)
  visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
  visible = mtf.logical_and(visible, mtf.less_equal(
      relative_position, 0 if fully_autoregressive else radius))
  if read_priority is not None:
    write_priority = _reshape_memory(write_priority)
    read_priority = _reshape_query(read_priority)
    visible = mtf.logical_and(
        visible, mtf.greater_equal(read_priority, write_priority))

  bias = attention.visibility_mask_to_attention_bias(visible, q.dtype)
  o = attention.attention(q, k, v, padded_memory_block_length, key_dim,
                          value_dim, bias, context=context,
                          **attention_kwargs)
  return mtf.replace_dimensions(o, [num_blocks, query_block_length], length_dim)
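The block size chosen at the top of local_attention_1d is the largest divisor of length_per_split that does not exceed max(radius, 128); with radius=128 and length_per_split=1000, for instance, the loop settles on 125. A small plain-Python check of that selection logic:

def choose_block_length(length_per_split, radius):
  # Greatest divisor of length_per_split that is <= max(radius, 128),
  # mirroring the loop at the start of local_attention_1d.
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  return block_length

assert choose_block_length(1000, 128) == 125
assert choose_block_length(1024, 128) == 128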
Code Example #27
0
 def value_heads_dims(self):
     """Dimensionality of number of value heads."""
     if self.config.attention_num_value_heads is None:
         raise ValueError("The number of value heads is not defined.")
     return mtf.Dimension("value_heads",
                          self.config.attention_num_value_heads)
Code Example #28
0
def lpt_prototype(mesh,
                  nc=FLAGS.nc,
                  bs=FLAGS.box_size,
                  batch_size=FLAGS.batch_size,
                  a0=FLAGS.a0,
                  a=FLAGS.af,
                  nsteps=FLAGS.nsteps):
  """
    Prototype of a function computing the LPT displacement.

    Returns output tensorflow and mesh tensorflow tensors
    """

  klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0]
  plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
  ipklin = iuspline(klin, plin)
  stages = np.linspace(a0, a, nsteps, endpoint=True)

  # Define the named dimensions
  # Parameters of the small scales decomposition
  n_block_x = FLAGS.nx
  n_block_y = FLAGS.ny
  n_block_z = 1
  halo_size = FLAGS.hsize

  if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z):
    new_size = int(0.5 *
                   min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
    print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size))
    halo_size = new_size

  # Parameters of the large scales decomposition
  downsampling_factor = 0
  lnc = nc // 2**downsampling_factor

  #

  fx_dim = mtf.Dimension("nx", nc)
  fy_dim = mtf.Dimension("ny", nc)
  fz_dim = mtf.Dimension("nz", nc)

  tfx_dim = mtf.Dimension("tx", nc)
  tfy_dim = mtf.Dimension("ty", nc)
  tfz_dim = mtf.Dimension("tz", nc)

  tx_dim = mtf.Dimension("tx_lr", nc)
  ty_dim = mtf.Dimension("ty_lr", nc)
  tz_dim = mtf.Dimension("tz_lr", nc)

  nx_dim = mtf.Dimension('nx_block', n_block_x)
  ny_dim = mtf.Dimension('ny_block', n_block_y)
  nz_dim = mtf.Dimension('nz_block', n_block_z)

  sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
  sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
  sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

  k_dims = [tx_dim, ty_dim, tz_dim]

  batch_dim = mtf.Dimension("batch", batch_size)
  pk_dim = mtf.Dimension("npk", len(plin))
  pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

  # Compute necessary Fourier kernels
  kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
  kx = mtf.import_tf_tensor(mesh,
                            kvec[0].squeeze().astype('float32'),
                            shape=[tfx_dim])
  ky = mtf.import_tf_tensor(mesh,
                            kvec[1].squeeze().astype('float32'),
                            shape=[tfy_dim])
  kz = mtf.import_tf_tensor(mesh,
                            kvec[2].squeeze().astype('float32'),
                            shape=[tfz_dim])
  kv = [ky, kz, kx]

  # kvec for low resolution grid
  kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
  kx_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[0].squeeze().astype('float32'),
                               shape=[tx_dim])
  ky_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[1].squeeze().astype('float32'),
                               shape=[ty_dim])
  kz_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[2].squeeze().astype('float32'),
                               shape=[tz_dim])
  kv_lr = [ky_lr, kz_lr, kx_lr]

  shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
  part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

  # Begin simulation

  initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

  #    # Reshaping array into high resolution mesh
  #    field = mtf.slicewise(lambda x:tf.expand_dims(tf.expand_dims(tf.expand_dims(x, axis=1),axis=1),axis=1),
  #                      [initc],
  #                      output_dtype=tf.float32,
  #                      output_shape=hr_shape,
  #                      name='my_reshape',
  #                      splittable_dims=lr_shape[:-1]+hr_shape[1:4]+part_shape[1:3])
  #

  state = mtfpm.lpt_init_single(
      initc,
      a0,
      kv_lr,
      halo_size,
      lr_shape,
      hr_shape,
      part_shape[1:],
      antialias=True,
  )
  # Here we can run our nbody
  final_state = state  #mtfpm.nbody(state, stages, lr_shape, hr_shape, k_dims, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor)

  # paint the field
  final_field = mtf.zeros(mesh, shape=hr_shape)
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.pad(final_field, [halo_size, halo_size],
                          block_size_dim.name)
  final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
  # Halo exchange
  for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]):
    final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                  halo_size)
  # Remove borders
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                            block_size_dim.name)

  #final_field = mtf.reshape(final_field,  [batch_dim, fx_dim, fy_dim, fz_dim])
  # Hack using a custom reshape because mesh is pretty dumb
  final_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [final_field],
                              output_dtype=tf.float32,
                              output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                              name='my_dumb_reshape',
                              splittable_dims=part_shape[:-1] + hr_shape[:4])

  return initc, final_field
Code Example #29
0
def transformer_moe_layer_v2(
    inputs, output_dim, hparams, train, variable_dtype,
    layout=None, mesh_shape=None, nonpadding=None):
  """2-level mixture of experts.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py .
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_capacity_factor_second_level: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  One set of parameters is used for the experts in the first level, and a
  different set of parameters per expert in the second level.
  The number of parameters in the gating network is:
    (input_dim.size * hparams.num_experts) +
    (moe_hidden_size * hparams.num_experts * hparams.num_experts)


  The number of parameters in the experts themselves is:
    (hparams.num_experts
     * (input_dim.size + output_dim.size)
     * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-3 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding)
    that each sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to divide
    the batch into "groups".  For each group, the same number of elements
    are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
  a, b: batch size
  l: original sequence length
  m: input depth
  n: output depth
  g, h: number of groups
  s, t: group size
  x, y: number of experts
  c, d: expert capacity

  input: [a0, b1, l, m]
  input: [a0, g1, s, m]
  dispatch_tensor_x: [a0, g1, s, x, c]
  expert_input: [a0, g1, x, c, m]
  alltoall: [a0, g, x1, c, m]
  alltoall: [a0, g, x1, c, m]
  transpose: [x1, a0, g, c, m]
  reshape: [x1, h0, s, m]
  assignment2: [x1, h0, t, y, d]
  expert_input2: [x1, h0, y, d, m]
  alltoall: [x1, h, y0, d, m]
  ...
  reverse of that

  gating params 0: [m, x]
  gating params 1: [x1, m, y]

  expert params:
     [x1, y0, m, hidden]
     [x1, y0, hidden, n]

  Args:
    inputs: a mtf.Tensor with shape [a, b, l, m]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional mtf.Tensor with shape [a, b, l]
      and the same dtype as inputs, consisting of ones(nonpadding)
      and zeros(padding).

  Returns:
    outputs: a Tensor with shape [a, b, l, n]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
  if nonpadding is not None:
    nonpadding = mtf.zeros(inputs.mesh, inputs.shape.dims[:-1],
                           dtype=inputs.dtype) + nonpadding
  insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
  if insert_outer_batch_dim:
    inputs = mtf.reshape(
        inputs, [mtf.Dimension("outer_batch", 1)] + inputs.shape.dims)

  assert len(hparams.moe_num_experts) == 2
  a0, b1, l, m = inputs.shape.dims
  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
  x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
  y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
  x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
  y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
  n = output_dim

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups (g.size) is a multiple of the mesh dimension
  # over which those groups are split.
  num_groups, group_size = _split_into_groups(
      b1.size * l.size, hparams.moe_group_size,
      mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, b1))
  g1 = mtf.Dimension(b1.name, num_groups)
  g = mtf.Dimension(b1.name + "_unsplit", g1.size)
  s = mtf.Dimension("group_size_x", group_size)

  # Each sequence sends (at most?) expert_capacity positions to each expert.
  # Static expert_capacity dimension is needed for expert batch sizes
  if train:
    capacity_factor = hparams.moe_capacity_factor_train
  else:
    capacity_factor = hparams.moe_capacity_factor_eval
  expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
  expert_capacity = max(expert_capacity, 4)
  c = mtf.Dimension("expert_capacity_x", expert_capacity)

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups (h.size) is a multiple of the mesh dimension
  # over which those groups are split.
  num_groups, group_size = _split_into_groups(
      a0.size * g.size * c.size,
      hparams.moe_group_size,
      mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, a0))
  t = mtf.Dimension("group_size_y", group_size)
  h0 = mtf.Dimension(a0.name, num_groups)
  h = mtf.Dimension(a0.name + "_unsplit", h0.size)

  expert_capacity = min(
      t.size,
      int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
  expert_capacity = max(expert_capacity, 4)
  d = mtf.Dimension("expert_capacity_y", expert_capacity)

  # First level of expert routing
  # Reshape the inner batch size to a multiple of group_dim g1 and
  # group_size_dim s.
  inputs = mtf.reshape(inputs, [a0, g1, s, m])
  if nonpadding is not None:
    nonpadding = mtf.reshape(nonpadding, [a0, g1, s])

  # Get the assignments for the first level.
  # dispatch_tensor_x has shape [a0, g1, s, x, c]
  if hparams.moe_gating == "top_2":
    dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=x,
        expert_capacity_dim=c,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        name="outer_gating",
        importance=nonpadding)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  # Now create expert_inputs based on the assignments.
  # put num_experts dimension first to make split easier in alltoall
  expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x], [x, a0, g1, c, m])

  # we construct an "importance" Tensor for the inputs to the second-level
  # gating.  The importance of an input is 1.0 if it represents the
  # first-choice expert-group and 0.5 if it represents the second-choice expert
  # group.  This is used by the second-level gating.
  importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
  importance = 0.5 * (
      mtf.to_float(mtf.greater(importance, 0.5)) +
      mtf.to_float(mtf.greater(importance, 0.0)))

  # First level, all to all. Here we change the split dimension from g1 to x1.
  expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape(
      [x1, a0, g, c, m]))
  importance = mtf.reshape(importance, [x1, a0, g, c])

  # Second level of expert routing
  # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
  # and group_size_dim t.
  inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
  importance = mtf.reshape(importance, [x1, h0, t])

  # Get the assignments for the second level.
  # dispatch_tensor_y has shape [x1, h0, t, y, d]
  if hparams.moe_gating == "top_2":
    dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
        inputs=inputs_y,
        outer_expert_dims=[x1],
        experts_dim=y,
        expert_capacity_dim=d,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=importance,
        name="inner_gating")
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  # Now create expert_inputs based on the assignments.
  # put num_experts dimension first to make split easier in alltoall
  expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y], [y, x1, h0, d, m])

  # Second level, all to all. Here we change the split dimension from h0 to y0.
  expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape(
      [y0, x1, h, d, m]))

  hidden_output = mtf.layers.dense(
      expert_inputs_y, hidden_dim, expert_dims=[y0, x1],
      activation=mtf.relu, use_bias=False, variable_dtype=variable_dtype,
      name="wi")
  expert_output = mtf.layers.dense(
      hidden_output, output_dim, expert_dims=[y0, x1],
      use_bias=False, variable_dtype=variable_dtype,
      name="wo")

  # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
  # expert_output has shape [y0, x1, h, d, n]

  # alltoall
  expert_output = mtf.reshape(expert_output, mtf.Shape(
      [y, x1, h0, d, n]))

  # combine results from inner level
  output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

  # Reshape the combined tensor from inner level to now contain outer_batch_dim
  # a0 and group_dim g
  output = mtf.reshape(output_y, [x1, a0, g, c, n])

  # alltoall from expert_dim x to group_dim g1
  expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

  # combine results from outer level
  output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])

  # Reshape the combined tensor to now contain inner_batch_dim
  # b1 and the original sequence length
  output = mtf.reshape(output_x, [a0, b1, l, n])
  if insert_outer_batch_dim:
    output = mtf.reshape(output, [b1, l, n])
  return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
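The first-level expert capacity above is min(group_size, group_size * capacity_factor / num_experts), floored at 4; for example, groups of 1024 positions routed to 32 experts with a capacity factor of 1.25 give 40 slots per expert. A small plain-Python check of that arithmetic:

def expert_capacity(group_size, capacity_factor, num_experts):
  # Same arithmetic as the first-level capacity computation above.
  cap = min(group_size, int((group_size * capacity_factor) / num_experts))
  return max(cap, 4)

assert expert_capacity(1024, 1.25, 32) == 40
assert expert_capacity(16, 1.0, 8) == 4  # the floor of 4 kicks in (16 / 8 = 2)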
Code Example #30
0
    def fn(x):
        with tf.variable_scope(scope):
            nx = x.shape[-1]  # Grab last dimension from input

            if use_rezero:
                prenorm = identity
            elif use_scale_norm:
                prenorm = scale_norm
            else:
                prenorm = layer_norm

            pre_residual_fn = rezero if use_rezero else identity

            attention_type = params["attention_types"][layer_num]

            if macaron_attention:
                mult = 0.5
                mlp_fn = mlp_glu if use_mlp_glu else mlp
                intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)
                # Define intermediate layer of mlp - to split
                dim_intermediate_expanded = mtf.Dimension(
                    "intermediate_expanded", intermediate_size)
                m = mlp_fn(x,
                           "mlp_macaron",
                           dim_intermediate_expanded,
                           variable_dtype=variable_dtype,
                           params=params)

                x = x + (m * mult)
            else:
                mult = 1

            if attention_type != "none":
                res_x = prenorm(x,
                                "norm_1",
                                variable_dtype=variable_dtype,
                                params=params)
                a = attn(res_x,
                         "attn",
                         nx,
                         attention_type=attention_type,
                         params=params,
                         bias=bias,
                         dim_seq=sequence_dim,
                         memory_length_dim=memory_length_dim,
                         variable_dtype=variable_dtype,
                         context=context)
            else:
                a = x

            x = x + pre_residual_fn(a, "norm_rezero_1", dtype=variable_dtype)

            res_x = prenorm(x,
                            "norm_2",
                            variable_dtype=variable_dtype,
                            params=params)

            if use_moe:
                moe_params = mtf.transformer.moe.HParams()
                mtf.transformer.moe.set_default_moe_hparams(moe_params)
                moe_params.add_hparam("moe_min_expert_capacity", 1)
                moe_params.add_hparam("moe_use_experts_attention", False)

                # Override defaults
                for k, v in params["moe_params"].items():
                    moe_params.add_hparam(k, v)

                moe_train = params["mode"] == "train"

                m, aux_loss = mtf.transformer.moe.transformer_moe_layer_v1(
                    res_x,
                    x.shape[-1],
                    moe_params,
                    train=moe_train,
                    mesh_shape=params["mesh_shape"],
                    layout=params["layout"],
                    activation=params.get("moe_activation", "relu"),
                    variable_dtype=variable_dtype,
                    num_microbatches=params["num_microbatches"])
                m = mtf.dropout(m,
                                rate=params["res_dropout"],
                                name="moe_dropout")
            else:

                mlp_fn = mlp_glu if use_mlp_glu else mlp
                intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)

                # Define intermediate layer of mlp - to split
                dim_intermediate_expanded = mtf.Dimension(
                    "intermediate_expanded", intermediate_size)

                m = mlp_fn(res_x,
                           "mlp",
                           dim_intermediate_expanded,
                           variable_dtype=variable_dtype,
                           params=params)
                aux_loss = mtf.zeros(x.mesh,
                                     mtf.Shape([]),
                                     dtype=variable_dtype.slice_dtype)

            x = x + pre_residual_fn(
                (m * mult), "norm_rezero_2", variable_dtype)
            return x, aux_loss