def import_to_batch_by_length(x, name):
    """Import a tf.Tensor into the mesh as a [batch_dim, length_dim] tensor.

    `mesh`, `batch_dim` and `self` are not parameters; they appear to come
    from an enclosing scope that is not visible here.

    Args:
      x: the tf.Tensor to import.
      name: a string, the name given to the imported mtf.Tensor.

    Returns:
      the mtf.Tensor produced by `mtf.import_tf_tensor`.
    """
    batch_by_length_shape = mtf.Shape([batch_dim, self.length_dim])
    return mtf.import_tf_tensor(mesh, x, batch_by_length_shape, name=name)
Example #2
0
  def __init__(self,
               mesh: mtf.Mesh,
               vocab_dim: mtf.Dimension,
               output_dim: mtf.Dimension,
               variable_dtype: mtf.VariableDType,
               name: str,
               ensemble_dim: mtf.Dimension,
               extra_ids: int = 0,
               dropout_rate: float = 0.0,
               gate_embedding_size: int = gin.REQUIRED,
               frequent_token_fraction: float = 0.1,
               noise_std_dev: float = 0.0):
    """Configurable embedding for the vocabulary.

    Most of the arguments get passed to `mtf.layers.embedding_weights`.

    Mixtape shares gates for low frequency tokens to improve efficiency. Since
    our vocabs are sorted in decreasing order of frequency with sentinels
    appended to the end, we need to do a little trick to ensure that the
    sentinels are treated as high frequency. If you want to treat the sentinels
    as low frequency tokens, then pass in zero for `extra_ids`.

    This constructor only records configuration and creates variables: the
    word embedding table, per-gate context weights and bias, and the prior
    (gating) weights, biases and vectors. `extra_ids`, `dropout_rate` and
    `noise_std_dev` are only stored here; presumably other methods of the
    class consume them.

    Args:
      mesh: the mesh used to layout the tensors.
      vocab_dim: the dimension corresponding to vocabulary.
      output_dim: the dimension corresponding to the model hidden states.
      variable_dtype: the datatype information for the variables used in the
        embedding tensors.
      name: a name to base variable names off of.
      ensemble_dim: the dimension used for ensembling. Absolutely no guarantees
        that this code will work with ensembling. May be falsy (e.g. None), in
        which case no ensemble dimension is added to the variable shapes.
      extra_ids: a non-negative integer, the number of sentinels at the end of
        the vocab.
      dropout_rate: a float between 0 and 1, the rate to use for dropout.
      gate_embedding_size: a positive integer, the size to use for embedding for
        the gates. It is usually chosen to be much smaller than d_model.
      frequent_token_fraction: a float between 0 and 1, what fraction of tokens
        to consider as high frequency and not share gates for.
      noise_std_dev: a non-negative float, the standard deviation of the
        Gaussian noise to add to the pre-activation priors.
    """
    self._extra_ids = extra_ids
    self._dropout_rate = dropout_rate
    self._noise_std_dev = noise_std_dev
    self._mesh = mesh
    self._vocab_dim = vocab_dim
    # Split the vocab into a "frequent" part of
    # int(frequent_token_fraction * vocab_dim.size) ids and a "rare" remainder.
    # Both dims reuse vocab_dim's name.
    self._frequent_vocab_dim = mtf.Dimension(
        vocab_dim.name, int(frequent_token_fraction * vocab_dim.size))
    self._rare_vocab_dim = mtf.Dimension(
        vocab_dim.name, vocab_dim.size - self._frequent_vocab_dim.size)
    self._output_dim = output_dim
    # Same size as output_dim but a different name, so that a single variable
    # can carry two output-sized dims (used for _context_weights below).
    self._copy_output_dim = mtf.Dimension("_{}_copy".format(output_dim.name),
                                          output_dim.size)
    # NOTE(review): both dims are named "gates" but have sizes 3 and 4; no
    # variable here uses both. Presumably the 3 pre-gate values are turned
    # into 4 gate values in another method of this class. Confirm there.
    self._pre_gates_dim = mtf.Dimension("gates", 3)
    self._gates_dim = mtf.Dimension("gates", 4)
    self._gate_embedding_dim = mtf.Dimension("gate_embedding",
                                             gate_embedding_size)

    # Word embedding table over vocab_dim and output_dim.
    self._embedding_weights = mtf.layers.embedding_weights(
        mesh=mesh,
        vocab_dim=vocab_dim,
        output_dim=output_dim,
        variable_dtype=variable_dtype,
        name="{}_embedding_weights".format(name),
        ensemble_dim=ensemble_dim)
    # Leading dims prepended to every remaining variable; empty when no
    # ensemble dimension was given.
    ensemble_dims = [ensemble_dim] if ensemble_dim else []
    # One [copy_output_dim, output_dim] matrix per gate. ensemble_dim receives
    # a list here (ensemble dims + gates_dim); this assumes
    # mtf.layers.embedding_weights accepts a list of dims. Verify.
    self._context_weights = mtf.layers.embedding_weights(
        mesh=mesh,
        vocab_dim=self._copy_output_dim,
        output_dim=output_dim,
        variable_dtype=variable_dtype,
        name="{}_context_weights".format(name),
        ensemble_dim=ensemble_dims + [self._gates_dim])
    # Zero-initialized bias of shape [ensemble..., gates(4), output_dim].
    self._context_weights_bias = mtf.get_variable(
        mesh,
        name="{}_context_weights_bias".format(name),
        shape=mtf.Shape(ensemble_dims + [self._gates_dim, output_dim]),
        dtype=variable_dtype,
        initializer=tf.zeros_initializer())

    # Prior (gating) parameters. One [gate_embedding, output_dim] matrix per
    # pre-gate.
    self._prior_weights = mtf.layers.embedding_weights(
        mesh=mesh,
        vocab_dim=self._gate_embedding_dim,
        output_dim=output_dim,
        variable_dtype=variable_dtype,
        name="{}_prior_weights".format(name),
        ensemble_dim=ensemble_dims + [self._pre_gates_dim])
    # Zero-initialized bias of shape [ensemble..., gates(3), gate_embedding].
    self._prior_weights_bias = mtf.get_variable(
        mesh,
        name="{}_prior_weights_bias".format(name),
        shape=mtf.Shape(ensemble_dims +
                        [self._pre_gates_dim, self._gate_embedding_dim]),
        dtype=variable_dtype,
        initializer=tf.zeros_initializer())
    # Per-token gate embedding, sized by the frequent part of the vocab only.
    # This is consistent with rare tokens sharing gates, as the docstring says.
    self._prior_vocab_vector = mtf.get_variable(
        mesh,
        name="{}_prior_vocab_vector".format(name),
        shape=mtf.Shape(ensemble_dims +
                        [self._frequent_vocab_dim, self._gate_embedding_dim]),
        dtype=variable_dtype,
        initializer=tf.random_normal_initializer())
    # Shape [ensemble..., gates(3), output_dim], random-normal initialized.
    self._prior_gates_vector = mtf.get_variable(
        mesh,
        name="{}_prior_gates_vector".format(name),
        shape=mtf.Shape(ensemble_dims + [self._pre_gates_dim, output_dim]),
        dtype=variable_dtype,
        initializer=tf.random_normal_initializer())
    # Per-frequent-token bias over the 3 pre-gates, random-normal initialized.
    self._prior_bias = mtf.get_variable(
        mesh,
        name="{}_prior_bias".format(name),
        shape=mtf.Shape(ensemble_dims +
                        [self._frequent_vocab_dim, self._pre_gates_dim]),
        dtype=variable_dtype,
        initializer=tf.random_normal_initializer())
    def test_get_laidout_tensors(self, is_eval_mode):
        """Checks how many datasets SimdMeshImplInputReader creates per mode.

        The dataset creator below is stateful. Its first call yields rows 0-1
        of a 4x4 identity matrix and its second call yields rows 2-3.

        - In eval mode, both cores are expected to see rows 0-1, i.e. one
          dataset object.
        - Otherwise, core 0 sees rows 0-1 and core 1 sees rows 2-3, i.e. two
          dataset objects.

        Args:
          is_eval_mode: forwarded to SimdMeshImplInputReader.
        """
        mesh_shape = "mesh_x:2, mesh_y:1"
        layout = "batch:mesh_x, io:mesh_y"
        batch_io_dim = 4
        # Number of rows returned by each call of stateful_ds_creator().
        rows_per_call = 2

        with tf.Session() as sess:
            topology, num_cores = self.initialize_system(sess)
            # Always shut the TPU system down, even when an assertion below
            # fails, so a failing case does not leave it initialized.
            try:
                # Get a device_assignment object for mtf.
                d_assignment = device_assignment.device_assignment(
                    topology,
                    computation_shape=[
                        1,
                    ] * mtf.utils.topology_rank(topology),
                    num_replicas=num_cores)

                # Hacked dataset creator: creates different datasets for the
                # first and second call, in order to test
                # SimdMeshImplInputReader.
                self.sub_batch_created_times = 0

                def stateful_ds_creator():
                    whole_batch = tf.eye(batch_io_dim, dtype=tf.float32)
                    sub_batch = tf.slice(
                        whole_batch,
                        [self.sub_batch_created_times * rows_per_call, 0],
                        [rows_per_call, batch_io_dim])
                    self.sub_batch_created_times += 1
                    return tf.data.Dataset.from_tensors(
                        sub_batch).repeat().unbatch()

                batch_dim = mtf.Dimension("batch", batch_io_dim)
                io_dim = mtf.Dimension("io", batch_io_dim)
                mtf_input_shapes = [mtf.Shape([batch_dim, io_dim])]

                # Get mesh_impl.
                mesh_shape = mtf.convert_to_shape(mesh_shape)
                layout_rules = mtf.convert_to_layout_rules(layout)
                mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    mesh_shape, layout_rules, None, d_assignment)

                simd_input_reader = input_reader.SimdMeshImplInputReader(
                    mesh_impl,
                    stateful_ds_creator,
                    mtf_input_shapes,
                    external_worker=False,
                    is_eval_mode=is_eval_mode)

                def model_fn(features):
                    return features

                replicated_computation = tpu.replicate(
                    computation=model_fn,
                    inputs=[[]] * num_cores,
                    infeed_queue=simd_input_reader.infeed_queue,
                    device_assignment=d_assignment)

                simd_input_reader.start_infeed_thread(sess, 1)
                results = sess.run(replicated_computation)
                print("results: {}".format(results))

                core_0_data = results[0][0]
                core_1_data = results[1][0]
                print("core_0_data: {}".format(core_0_data))
                print("core_1_data: {}".format(core_1_data))

                if is_eval_mode:
                    # If there is only one dataset object, then the
                    # stateful_ds_creator() should be called only once.
                    self.assertAllClose(
                        np.array([[1, 0, 0, 0], [0, 1, 0, 0]],
                                 dtype=np.float32),
                        core_0_data)
                    self.assertAllClose(
                        np.array([[1, 0, 0, 0], [0, 1, 0, 0]],
                                 dtype=np.float32),
                        core_1_data)
                else:
                    # If there are two dataset objects, then the
                    # stateful_ds_creator() should be called twice.
                    self.assertAllClose(
                        np.array([[1, 0, 0, 0], [0, 1, 0, 0]],
                                 dtype=np.float32),
                        core_0_data)
                    self.assertAllClose(
                        np.array([[0, 0, 1, 0], [0, 0, 0, 1]],
                                 dtype=np.float32),
                        core_1_data)
            finally:
                sess.run(tf.tpu.shutdown_system())
Example #4
0
File: bert.py  Project: tensorflow/mesh
    def __init__(self,
                 config,
                 is_training,
                 input_ids,
                 input_mask=None,
                 token_type_ids=None,
                 scope=None,
                 mesh_shape="",
                 layout=""):
        """Builds the BERT encoder graph on the mesh of `input_ids`.

        Args:
          config: model configuration. It is deep-copied, so the inference-time
            dropout overrides below never touch the caller's object.
          is_training: a boolean. When False, every dropout probability in the
            copied config is set to 0.0.
          input_ids: an integer mtf.Tensor with exactly two dimensions. The
            second dimension is used as the sequence dimension.
          input_mask: an optional mtf.Tensor containing the sequence dimension
            (presumably [batch, seq]). Positions where it is 0 receive a
            -10000.0 attention bias.
          token_type_ids: an optional integer mtf.Tensor shaped like
            `input_ids`. Defaults to all zeros.
          scope: optional variable scope name (defaults to "bert").
          mesh_shape: forwarded to the mixture-of-experts layers.
          layout: forwarded to the mixture-of-experts layers.
        """
        self.config = copy.deepcopy(config)
        del config
        if not is_training:
            # Disable every dropout for inference (on our private config copy).
            self.config.layer_output_dropout_prob = 0.0
            self.config.attention_probs_dropout_prob = 0.0
            self.config.feedforward_intermediate_dropout_prob = 0.0
        input_shape = input_ids.shape
        assert input_shape.ndims == 2

        self._seq_dim = input_shape.dims[1]
        # Second sequence-sized dim, used as the key/memory axis of the
        # attention biases built in the encoder below.
        self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size)
        self._extra_losses = []
        mesh = input_ids.mesh

        if token_type_ids is None:
            token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32)

        with tf.variable_scope(scope, default_name="bert"):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                self.embedding_table = mtf.get_variable(
                    mesh,
                    "word_embeddings",
                    mtf.Shape([self.vocab_dim, self.model_dim]),
                    initializer=self.embedding_initializer)
                self.word_embedding_output = mtf.gather(
                    self.embedding_table, input_ids, self.vocab_dim)

                # Add positional embeddings and token type embeddings, then
                # layer normalize and perform dropout.
                self.embedding_output = self.word_embedding_output

                token_type_table = mtf.get_variable(
                    mesh,
                    "token_type_embeddings",
                    mtf.Shape([self.token_type_vocab_dim, self.model_dim]),
                    initializer=self.embedding_initializer)
                # token_type_ids is always set here (defaulted to zeros above).
                self.embedding_output += mtf.gather(
                    token_type_table, token_type_ids,
                    self.token_type_vocab_dim)
                if self.config.position_signal == "embedding":
                    full_position_table = mtf.get_variable(
                        mesh,
                        "position_embeddings",
                        mtf.Shape(
                            [self.max_position_embeddings_dim,
                             self.model_dim]),
                        initializer=self.embedding_initializer)
                    # Keep only the first seq_dim.size positions and rename
                    # that axis to the sequence dim so it broadcasts.
                    short_position_table = mtf.rename_dimension(
                        mtf.slice(full_position_table, 0, self.seq_dim.size,
                                  self.max_position_embeddings_dim.name),
                        self.max_position_embeddings_dim.name,
                        self.seq_dim.name)
                    self.embedding_output += short_position_table
                self.embedding_output = self.normalize(self.embedding_output)
                self.embedding_output = mtf.dropout(
                    self.embedding_output,
                    is_training,
                    keep_prob=1.0 - self.config.layer_output_dropout_prob)

            with tf.variable_scope("encoder"):
                attention_biases = []
                if input_mask is not None:
                    # [batch_dim, memory_seq_dim]: -10000.0 where the mask is 0.
                    attention_biases.append((1.0 - mtf.to_float(
                        mtf.replace_dimensions(input_mask, self.seq_dim,
                                               self.memory_seq_dim))) *
                                            -10000.0)
                if self.config.position_signal == "relative_attention_bias":
                    buckets_dim = mtf.Dimension("buckets", 32)
                    rp_bucket = _relative_position_bucket(
                        mtf.range(mesh, self.memory_seq_dim, tf.int32) -
                        mtf.range(mesh, self.seq_dim, tf.int32),
                        num_buckets=buckets_dim.size)
                    bias_var = mtf.get_variable(
                        mesh,
                        "relative_attention_bias",
                        [self.num_heads_dim, buckets_dim],
                        initializer=tf.zeros_initializer())
                    attention_biases.append(
                        mtf.gather(bias_var, rp_bucket, buckets_dim))
                attention_bias = mtf.add_n(attention_biases)
                prev_layer_output = self.embedding_output
                self.all_encoder_layers = []
                for block_num in range(self.config.num_blocks):
                    with tf.variable_scope("block_%d" % block_num):
                        for layer_idx, layer_type in enumerate(
                                self.config.block_layers):
                            # Repeated layer types in one block get a numeric
                            # suffix so their variable scopes stay unique.
                            layer_name = layer_type
                            count = self.config.block_layers[:layer_idx].count(
                                layer_type)
                            if count:
                                layer_name += "_%d" % count
                            with tf.variable_scope(layer_name):
                                x = prev_layer_output
                                if self.config.residual_structure == "direct":
                                    x = self.normalize(x)
                                if layer_type == "attention":
                                    x = self.self_attention(x, attention_bias)
                                elif layer_type == "feedforward":
                                    x = self.feedforward(x)
                                elif layer_type == "moe":
                                    x = self.moe(x, layout, mesh_shape,
                                                 input_mask, is_training)
                                else:
                                    raise ValueError("unknown layer type " +
                                                     layer_type)
                                x = mtf.dropout(
                                    x,
                                    is_training,
                                    keep_prob=1.0 -
                                    self.config.layer_output_dropout_prob)
                                layer_output = prev_layer_output + x
                                if self.config.residual_structure == "original":
                                    layer_output = self.normalize(layer_output)
                                prev_layer_output = layer_output
                    # Use prev_layer_output rather than the loop-local
                    # layer_output: it is defined even when block_layers is
                    # empty, and equal to it otherwise.
                    self.all_encoder_layers.append(prev_layer_output)

            self.sequence_output = prev_layer_output
            if self.config.residual_structure == "direct":
                self.sequence_output = self.normalize(self.sequence_output)

            # The "pooler" converts the encoded sequence tensor of shape
            # [batch_dim, seq_dim, hidden_size] to a tensor of shape
            # [batch_dim, hidden_size]. This is necessary for segment-level
            # (or segment-pair-level) classification tasks where we need a fixed
            # dimensional representation of the segment.
            with tf.variable_scope("pooler"):
                # We "pool" the model by simply taking the hidden state
                # corresponding to the first token. We assume that this has
                # been pre-trained.
                first_token_tensor = mtf.gather(self.sequence_output, 0,
                                                self.seq_dim)
                self.pooled_output = mtf.layers.dense(
                    first_token_tensor,
                    reduced_dims=[self.model_dim],
                    new_dims=[self.model_dim],
                    activation=mtf.tanh,
                    kernel_initializer=self.dense_initializer,
                    use_bias=self.config.use_bias)
Example #5
0
File: attention.py  Project: jahau/mesh
def _combined_dim(dims):
    """Merge a sequence of mtf.Dimensions into a single dimension.

    The merged dimension takes the name of the first dimension in `dims` and a
    size equal to the product of all their sizes (`mtf.Shape(dims).size`).
    """
    merged_name = dims[0].name
    merged_size = mtf.Shape(dims).size
    return mtf.Dimension(merged_name, merged_size)
Example #6
0
 def testReduceOperation(self):
     """A "sum" ReduceOperation of self.x to [b_dim] reports its split dims.

     Expects dims "a" and "b" to be splittable and no dims to be unsplittable.
     """
     target_shape = mtf.Shape([self.b_dim])
     op = mtf.ReduceOperation(self.x, target_shape, "sum")
     self.assertEqual(op.splittable_dims, frozenset(["a", "b"]))
     self.assertEqual(op.unsplittable_dims, frozenset())
Example #7
0
 def testConvertToShape(self, inputs):
     """`inputs` must convert to the shape [x:4, y:8].

     `inputs` is presumably supplied by a parameterized-test decorator that is
     not visible here.
     """
     want = mtf.Shape([mtf.Dimension("x", 4), mtf.Dimension("y", 8)])
     got = mtf.convert_to_shape(inputs)
     self.assertEqual(got, want)
Example #8
0
def transformer_moe_layer_v1(inputs,
                             output_dim,
                             hparams,
                             train,
                             master_dtype=tf.bfloat16,
                             slice_dtype=tf.float32):
    """Local mixture of experts that works well on TPU.

    Adapted from the paper https://arxiv.org/abs/1701.06538

    Note: until the algorithm and interface solidify, we pass in a
    hyperparameters object in order not to complicate the interface in
    mtf_transformer.py. Once this code moves out of "research", we should pass
    the hyperparameters separately.

    Hyperparameters used:
      hparams.moe_num_experts: number of experts
      hparams.moe_hidden_size: size of hidden layer in each expert
      hparams.moe_group_size: size of each "group" for gating purposes
      hparams.moe_capacity_factor_train: a float
      hparams.moe_capacity_factor_eval: a float
      hparams.moe_gating: a string (only "top_2" is supported)
      hparams.moe_loss_coef: a float, multiplier applied to the returned loss
      + all hyperparameters used by _top_2_gating()

    The number of parameters in the gating network (built inside
    _top_2_gating) is stated as:
      input_dim.size * hparams.moe_num_experts

    The number of parameters in the experts themselves is:
      (hparams.moe_num_experts
       * (input_dim.size + output_dim.size)
       * hparams.moe_hidden_size)

    The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
    of the representations of all positions in a batch of sequences.

    Each position of each sequence is sent to 0-2 experts. The expert
    choices and the combination weights are determined by a learned gating
    function.

    This function returns a small auxiliary loss that should be added to the
    training loss of the model. This loss helps to balance expert usage.
    Without the loss, it is very likely that a few experts will be trained and
    the rest will starve.

    Several hacks are necessary to get around current TPU limitations:

    - To ensure static shapes, we enforce (by truncation/padding) that each
      group sends the same number of elements to each expert.

      It would make more sense to enforce this equality over the entire batch,
      but due to our hacked-up gather-by-matmul implementation, we need to
      divide the batch into "groups". For each group, the same number of
      elements are sent to each expert.

    TODO(noam): Factor this code better. We want to be able to substitute
    different code for the experts themselves.

    Args:
      inputs: a mtf.Tensor with shape [<batch_dims...>, length_dim, input_dim]
      output_dim: a mtf.Dimension (for Transformer, this is input_dim)
      hparams: model hyperparameters
      train: a boolean
      master_dtype: a tf.dtype
      slice_dtype: a tf.dtype

    Returns:
      outputs: a Tensor with shape [<batch_dims...>, length_dim, output_dim]
      loss: a mtf scalar, already multiplied by hparams.moe_loss_coef

    Raises:
      ValueError: on unrecognized hparams.moe_gating
    """
    orig_inputs = inputs
    input_dim = inputs.shape.dims[-1]
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)
    group_size_dim = mtf.Dimension("group", hparams.moe_group_size)
    # Flatten every leading dim into [num_groups, group_size]. The group-count
    # dim reuses the name of the first input dim (presumably so that it is
    # laid out like the original batch dim).
    batch_dim = mtf.Dimension(
        orig_inputs.shape[0].name,
        orig_inputs.shape.size // (group_size_dim.size * input_dim.size))
    inputs = mtf.reshape(inputs, [batch_dim, group_size_dim, input_dim])

    # Each group sends expert_capacity positions to each expert.
    capacity_factor = (hparams.moe_capacity_factor_train
                       if train else hparams.moe_capacity_factor_eval)
    expert_capacity = min(
        group_size_dim.size,
        int((group_size_dim.size * capacity_factor) / experts_dim.size))
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

    # Same sizes, different names. Reshaping between the split and unsplit
    # variants below moves the "experts" vs "batch" naming (and so, presumably,
    # the layout split) from one axis to the other.
    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)

    if hparams.moe_gating == "top_2":
        dispatch_tensor, combine_tensor, loss = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=experts_dim_unsplit,
            expert_capacity_dim=expert_capacity_dim,
            hparams=hparams,
            train=train)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # put num_experts dimension first to make split easier in alltoall
    expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                               mtf.Shape([
                                   experts_dim_unsplit, batch_dim,
                                   expert_capacity_dim, input_dim
                               ]))

    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape(
            [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

    # Now feed the expert inputs through the experts.
    h = mtf.layers.dense(expert_inputs,
                         hidden_dim,
                         expert_dims=[experts_dim],
                         activation=mtf.relu,
                         use_bias=False,
                         master_dtype=master_dtype,
                         slice_dtype=slice_dtype,
                         name="x0")
    expert_output = mtf.layers.dense(h,
                                     output_dim,
                                     expert_dims=[experts_dim],
                                     use_bias=False,
                                     master_dtype=master_dtype,
                                     slice_dtype=slice_dtype,
                                     name="x1")

    # The second dense layer maps to output_dim, so the last dim here is
    # output_dim, not input_dim. Using input_dim broke this layer whenever
    # output_dim differed from input_dim.
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape(
            [experts_dim_unsplit, batch_dim, expert_capacity_dim, output_dim]))

    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape([batch_dim, group_size_dim, output_dim]))

    output = mtf.reshape(output, orig_inputs.shape.dims[:-1] + [output_dim])

    return output, loss * hparams.moe_loss_coef
Example #9
0
    def _mtf_model_fn(self, features, mesh):
        """Builds the transformer graph and its training loss on `mesh`.

        Depending on hparams.transformer_type ("decoder", "encdec" or
        "encoder"), the encoder stack, the decoder stack, or both are built.
        Packed datasets are detected via the "*_segmentation" features.

        Args:
          features: a dict of tf.Tensors. It must contain "targets", and
            "inputs" unless hparams.transformer_type == "decoder". Optional
            packing features: "targets_segmentation", "targets_position",
            "inputs_segmentation", "inputs_position".
          mesh: the mtf.Mesh to import the features into.

        Returns:
          A (logits, loss) pair:
            logits: float32 logits. When there are several batch dims, they
              are merged into one.
            loss: a scalar mtf loss, including any extra losses collected from
              the layer stacks.

        Raises:
          ValueError: if an encdec model receives packed inputs
            ("inputs_segmentation") without "targets_segmentation".
        """
        features = copy.copy(features)
        hparams = self._hparams
        targets = tf.to_int32(features["targets"])
        if len(targets.get_shape()) > 2:
            tf.logging.info("targets = %s" % targets)
            targets = tf.squeeze(targets, [2, 3])

        # pad targets to max_length
        def pad_to_max_length(x):
            # Right-pad the length axis to hparams.max_length and pin the
            # static shape to [batch_size, max_length].
            extra_length = hparams.max_length - tf.shape(x)[1]
            x = tf.pad(x, [[0, 0], [0, extra_length]])
            x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
            return x

        targets = pad_to_max_length(targets)
        for key in [
                "targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"
        ]:
            if key in features:
                features[key] = pad_to_max_length(features[key])
        shifted_targets = common_layers.shift_right_2d(targets)

        targets = self._import_to_batch_by_length(targets, "targets", mesh,
                                                  hparams)
        shifted_targets = self._import_to_batch_by_length(
            shifted_targets, "shifted_targets", mesh, hparams)

        if "targets_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            targets_segmentation = self._import_to_batch_by_length(
                features["targets_segmentation"], "targets_segmentation", mesh,
                hparams)
            targets_position = self._import_to_batch_by_length(
                features["targets_position"], "targets_position", mesh,
                hparams)
            decoder_self_attention_mask = (
                mtf.layers.attention_mask_autoregressive(
                    targets_position, dtype=self.activation_dtype) +
                mtf.layers.attention_mask_same_segment(
                    targets_segmentation, dtype=self.activation_dtype))
        else:
            targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
            decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype)

        def layer_prepostprocess_dropout(x):
            # The noise is shared across the length dim (noise_shape only has
            # the batch and model dims).
            return mtf.dropout(
                x,
                keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
                noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

        extra_losses = []
        (inputs_embedding_var, targets_embedding_var, softmax_var,
         positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
        if hparams.transformer_type == "decoder":
            encoder_output = None
            encoder_decoder_attention_mask = None
        else:
            inputs = tf.to_int32(features["inputs"])
            # Same handling as targets: only squeeze when the trailing
            # singleton dims are present, so 2-D inputs are accepted too.
            if len(inputs.get_shape()) > 2:
                inputs = tf.squeeze(inputs, [2, 3])
            inputs = pad_to_max_length(inputs)
            inputs = self._import_to_batch_by_length(inputs, "inputs", mesh,
                                                     hparams)
            if "inputs_segmentation" in features:
                # "Packed" dataset - keep the examples from seeing each other.
                inputs_segmentation = self._import_to_batch_by_length(
                    features["inputs_segmentation"], "inputs_segmentation",
                    mesh, hparams)
                inputs_position = self._import_to_batch_by_length(
                    features["inputs_position"], "inputs_position", mesh,
                    hparams)
                encoder_self_attention_mask = (
                    mtf.layers.attention_mask_same_segment(
                        inputs_segmentation, dtype=self.activation_dtype))
            else:
                inputs_position = mtf.range(mesh,
                                            self.length_dim,
                                            dtype=tf.int32)
                encoder_self_attention_mask = (
                    mtf.layers.attention_mask_ignore_padding(
                        inputs, dtype=self.activation_dtype))

            x = (mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim) +
                 mtf.gather(positional_embedding_var, inputs_position,
                            self.max_length_dim))
            x = layer_prepostprocess_dropout(x)
            with tf.variable_scope("encoder"):
                x = self._layer_stack(
                    x,
                    hparams.encoder_layers,
                    self_attention_mask=encoder_self_attention_mask,
                    losses=extra_losses)

        if hparams.transformer_type == "encdec":
            if "inputs_segmentation" in features:
                # targets_segmentation only exists for packed targets; fail
                # with a clear message instead of an UnboundLocalError.
                if "targets_segmentation" not in features:
                    raise ValueError(
                        "Packed inputs (inputs_segmentation) require "
                        "targets_segmentation for an encdec model.")
                encoder_decoder_attention_mask = (
                    mtf.layers.attention_mask_same_segment(
                        targets_segmentation,
                        inputs_segmentation,
                        dtype=self.activation_dtype))
            else:
                encoder_decoder_attention_mask = encoder_self_attention_mask
            # The decoder attends over the encoder output along the memory
            # length dim.
            encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                                  self.memory_length_dim.name)

        if hparams.transformer_type != "encoder":
            # DECODER
            x = (mtf.gather(targets_embedding_var, shifted_targets,
                            self.targets_vocab_dim) +
                 mtf.gather(positional_embedding_var, targets_position,
                            self.max_length_dim))
            x = layer_prepostprocess_dropout(x)
            with tf.variable_scope("decoder"):
                x = self._layer_stack(
                    x,
                    hparams.decoder_layers,
                    encoder_output=encoder_output,
                    self_attention_mask=decoder_self_attention_mask,
                    encdec_attention_mask=encoder_decoder_attention_mask,
                    losses=extra_losses)
        logits = mtf.matmul(x, softmax_var)
        if hparams.mode == tf.estimator.ModeKeys.TRAIN:
            logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
        # Label smoothing: spread `label_smoothing` mass uniformly over the
        # vocab; the true class keeps the rest.
        off_value = hparams.label_smoothing / self._targets_vocab_size
        on_value = 1.0 - hparams.label_smoothing + off_value
        soft_targets = mtf.one_hot(targets,
                                   self.targets_vocab_dim,
                                   on_value=on_value,
                                   off_value=off_value,
                                   dtype=self.activation_dtype)
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, self.targets_vocab_dim)
        weights = mtf.layers.weights_nonzero(targets,
                                             dtype=self.activation_dtype)
        loss = mtf.reduce_mean(loss * weights)
        for extra_loss in extra_losses:
            loss += extra_loss
        logits = mtf.to_float(logits)
        # combine batch dims
        if len(self.batch_dims) > 1:
            combined_batch_dim = mtf.Dimension(self.batch_dims[0].name,
                                               mtf.Shape(self.batch_dims).size)
            logits = mtf.reshape(logits,
                                 [combined_batch_dim] + logits.shape.dims[-2:])
        return logits, loss
Example #10
0
    def test_get_indices(self):
        """Checks ProductKeyValueMemory.get_indices on a hand-computed case.

        The query is all ones and every key has a single component, so each
        score equals the key value. The expected top-knn scores and flat
        indices can therefore be worked out by hand (see the comments below).
        """
        key_size, n_keys, product_size = 2, 3, 2
        head_size, batch, seq_len, knn = 2, 2, 2, 2

        batch_dim = mtf.Dimension("batch", batch)
        seq_dim = mtf.Dimension("length", seq_len)
        head_dim = mtf.Dimension("n_heads", head_size)
        product_dim = mtf.Dimension("product_key", product_size)
        n_key_dim = mtf.Dimension("n_keys", n_keys)
        key_dim = mtf.Dimension("key", key_size // 2)
        knn_dim = mtf.Dimension("knn", knn)

        query = mtf.ones(
            self.mesh,
            mtf.Shape([batch_dim, seq_dim, head_dim, product_dim, key_dim]))

        # Worked example (n_k = n_keys = 3):
        # head 0:
        #   product 0 keys [4, 1, 2]  -> top-2 [4, 2] at ids [0, 2]
        #   product 1 keys [2, -1, 2] -> top-2 [2, 2] at ids [0, 2]
        #   best sums: 4+2 = 6 twice -> flat ids [0*n_k+0, 0*n_k+2] = [0, 2]
        # head 1:
        #   product 0 keys [1, 2, 5]  -> top-2 [5, 2] at ids [2, 1]
        #   product 1 keys [6, 1, 4]  -> top-2 [6, 4] at ids [0, 2]
        #   best sums: 5+6 = 11, 5+4 = 9 -> flat ids [2*n_k+0, 2*n_k+2] = [6, 8]
        per_head_scores = np.array([[6, 6], [11, 9]])
        per_head_indices = np.array([[0, 2], [6, 8]])
        full_shape = [batch, seq_len, head_size, knn]
        expected_scores = np.broadcast_to(per_head_scores, full_shape)
        expected_indices = np.broadcast_to(per_head_indices, full_shape)

        keys = mtf.constant(
            self.mesh,
            [
                [[[4], [1], [2]], [[2], [-1], [2]]],  # head 0
                [[[1], [2], [5]], [[6], [1], [4]]],  # head 1
            ],
            mtf.Shape([head_dim, product_dim, n_key_dim, key_dim]))

        memory = memory_layers.ProductKeyValueMemory(key_size, n_keys,
                                                     head_size, knn)
        mtf_scores, mtf_indices = memory.get_indices(keys, query)

        # Shapes.
        want_shape = mtf.Shape([batch_dim, seq_dim, head_dim, knn_dim])
        for result in (mtf_scores, mtf_indices):
            self.assertEqual(want_shape, result.shape)

        # Values.
        scores_lowering, tf_scores = self._export_to_tf_tensor(mtf_scores)
        indices_lowering, tf_indices = self._export_to_tf_tensor(mtf_indices)
        self.evaluate(tf.global_variables_initializer())
        for lowering in (scores_lowering, indices_lowering):
            self.evaluate(lowering.copy_masters_to_slices())
        got_scores, got_indices = self.evaluate([tf_scores, tf_indices])

        self.assertAllEqual(expected_scores, got_scores)
        self.assertAllEqual(expected_indices, got_indices)
Example #11
0
def transformer_moe_layer_v2(inputs,
                             output_dim,
                             hparams,
                             train,
                             master_dtype=tf.bfloat16,
                             slice_dtype=tf.float32):
    """2-level mixture of experts.

    Adapted from the paper https://arxiv.org/abs/1701.06538

    Note: until the algorithm and interface solidify, we pass in a
    hyperparameters object in order not to complicate the interface in
    mtf_transformer.py. Once this code moves out of "research", we should pass
    the hyperparameters separately.

    Hyperparameters read directly by this function:
      hparams.moe_num_experts: a length-2 sequence [x, y] (asserted). x is the
        number of first-level expert groups and y the number of second-level
        experts per group.
      hparams.moe_hidden_size: size of hidden layer in each expert
      hparams.moe_group_size: size of each "group" for gating purposes
      hparams.moe_capacity_factor_train: a float
      hparams.moe_capacity_factor_eval: a float
      hparams.moe_capacity_factor_second_level: a float
      hparams.moe_gating: a string; only "top_2" is supported
      hparams.moe_loss_coef: multiplier applied to the returned loss
      hparams.layout, hparams.mesh_shape: used to choose group counts that
        divide evenly over the mesh dimension the groups are split on
      + all hyperparameters used by _top_2_gating()

    The first level has one set of gating params. The second level has a
    separate set of gating params for each first-level expert group.

    Presumably (per the cheat sheet below; the weights are created inside
    _top_2_gating) the gating network has about
      input_dim.size * x + x * input_dim.size * y
    parameters.

    The number of parameters in the experts themselves (two bias-free dense
    layers with expert dims [y, x]) is:
      x * y * (input_dim.size + output_dim.size) * hparams.moe_hidden_size

    The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
    of the representations of all positions in a batch of sequences.

    Each position of each sequence is sent to 0-3 experts.  The expert
    choices and the combination weights are determined by a learned gating
    function.

    This function returns a small auxiliary loss that should be added to the
    training loss of the model.  This loss helps to balance expert usage.
    Without the loss, it is very likely that a few experts will be trained and
    the rest will starve.

    Several hacks are necessary to get around current TPU limitations:

    - To ensure static shapes, we enforce (by truncation/padding)
      that each sequence send the same number of elements to each expert.

      It would make more sense to enforce this equality over the entire batch,
      but due to our hacked-up gather-by-matmul implementation, we need to
      divide the batch into "groups".  For each group, the same number of
      elements are sent to each expert.

    TODO(noam): Factor this code better.  We want to be able to substitute
    different code for the experts themselves.

    Dimensions cheat sheet (a trailing 0/1 marks a dimension split over mesh
    dimension 0/1):
    a, b: batch size
    l: original sequence length
    m: input depth
    n: output depth
    g, h: number of groups
    s, t: group size
    x, y: number of experts
    c, d: expert capacity

    input: [a0, b1, l, m]
    input: [a0, g1, s, m]
    dispatch_tensor_x: [a0, g1, s, x, c]
    expert_input: [a0, g1, x, c, m]
    alltoall: [a0, g, x1, c, m]
    transpose: [x1, a0, g, c, m]
    reshape: [x1, h0, s, m]
    assignment2: [x1, h0, t, y, d]
    expert_input2: [x1, h0, y, d, m]
    alltoall: [x1, h, y0, d, m]
    ...
    reverse of that

    gating params 0: [m, x]
    gating params 1: [x1, m, y]

    expert params:
       [x1, y0, m, hidden]
       [x1, y0, hidden, n]

    Args:
      inputs: a mtf.Tensor with shape [a, b, l, m], or [b, l, m]. For 3-D
        input a size-1 "outer_batch" dimension is added here and removed
        again before returning.
      output_dim: a mtf.Dimension (for Transformer, this is input_dim)
      hparams: model hyperparameters
      train: a boolean; selects moe_capacity_factor_train vs. _eval
      master_dtype: a tf.dtype
      slice_dtype: a tf.dtype

    Returns:
      outputs: a Tensor with shape [a, b, l, n] ([b, l, n] for 3-D inputs)
      loss: a mtf scalar, (outer + inner gating loss) * hparams.moe_loss_coef

    Raises:
      ValueError: on unrecognized hparams.moe_gating
      AssertionError: if hparams.moe_num_experts does not have length 2
    """
    # Accept [b, l, m] inputs by temporarily adding a size-1 outer batch dim;
    # it is stripped again at the end.
    insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
    if insert_outer_batch_dim:
        inputs = mtf.reshape(inputs, [mtf.Dimension("outer_batch", 1)] +
                             inputs.shape.dims)

    assert len(hparams.moe_num_experts) == 2
    a0, b1, l, m = inputs.shape.dims
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
    # x1/y0 and their *_unsplit twins have equal sizes but different names.
    # Reshaping between the two names moves the mesh split (an all-to-all).
    x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
    y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
    x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
    y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
    n = output_dim

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (g.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        b1.size * l.size, hparams.moe_group_size,
        mtf.tensor_dim_to_mesh_dim_size(hparams.layout, hparams.mesh_shape,
                                        b1))
    g1 = mtf.Dimension(b1.name, num_groups)
    g = mtf.Dimension(b1.name + "_unsplit", g1.size)
    s = mtf.Dimension("group_size_x", group_size)

    # Each sequence sends (at most?) expert_capacity positions to each expert.
    # Static expert_capacity dimension is needed for expert batch sizes.
    # Capacity is capacity_factor * s / x, capped at s.size and then floored
    # at 4 (so it can exceed s.size when s.size < 4).
    capacity_factor = (hparams.moe_capacity_factor_train
                       if train else hparams.moe_capacity_factor_eval)
    expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
    expert_capacity = max(expert_capacity, 4)
    c = mtf.Dimension("expert_capacity_x", expert_capacity)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups (h.size) is a multiple of the mesh dimension
    # over which those groups are split.
    num_groups, group_size = _split_into_groups(
        a0.size * g.size * c.size, hparams.moe_group_size,
        mtf.tensor_dim_to_mesh_dim_size(hparams.layout, hparams.mesh_shape,
                                        a0))
    t = mtf.Dimension("group_size_y", group_size)
    h0 = mtf.Dimension(a0.name, num_groups)
    h = mtf.Dimension(a0.name + "_unsplit", h0.size)

    # Same capacity rule as the first level, using the second-level factor.
    expert_capacity = min(
        t.size,
        int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
    expert_capacity = max(expert_capacity, 4)
    d = mtf.Dimension("expert_capacity_y", expert_capacity)

    # First level of expert routing
    # Reshape the inner batch size to a multiple of group_dim g1 and
    # group_size_dim s.
    inputs = mtf.reshape(inputs, [a0, g1, s, m])

    # Get the assignments for the first level.
    # dispatch_tensor_x has shape [a0, g1, s, x, c]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
            inputs=inputs,
            outer_expert_dims=None,
            experts_dim=x,
            expert_capacity_dim=c,
            hparams=hparams,
            train=train)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x],
                                 [x, a0, g1, c, m])

    # we construct an "importance" Tensor for the inputs to the second-level
    # gating.  The importance of an input is 1.0 if it represents the
    # first-choice expert-group and 0.5 if it represents the second-choice expert
    # group.  This is used by the second-level gating.
    importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
    importance = 0.5 * (mtf.to_float(mtf.greater(importance, 0.5)) +
                        mtf.to_float(mtf.greater(importance, 0.0)))

    # First level, all to all. Here we change the split dimension from g1 to x1.
    expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape([x1, a0, g, c,
                                                              m]))
    importance = mtf.reshape(importance, [x1, a0, g, c])

    # Second level of expert routing
    # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
    # and group_size_dim t.
    inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
    importance = mtf.reshape(importance, [x1, h0, t])

    # Get the assignments for the second level.
    # dispatch_tensor_y has shape [x1, h0, t, y, d]
    if hparams.moe_gating == "top_2":
        dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
            inputs=inputs_y,
            outer_expert_dims=[x1],
            experts_dim=y,
            expert_capacity_dim=d,
            hparams=hparams,
            train=train,
            importance=importance)
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    # Now create expert_inputs based on the assignments.
    # put num_experts dimension first to make split easier in alltoall
    expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y],
                                 [y, x1, h0, d, m])

    # Second level, all to all. Here we change the split dimension from h0 to y0.
    expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape([y0, x1, h, d,
                                                              m]))

    # Two-layer, bias-free feed-forward network per (y0, x1) expert.
    hidden_output = mtf.layers.dense(expert_inputs_y,
                                     hidden_dim,
                                     expert_dims=[y0, x1],
                                     activation=mtf.relu,
                                     use_bias=False,
                                     master_dtype=master_dtype,
                                     slice_dtype=slice_dtype,
                                     name="expert0")
    expert_output = mtf.layers.dense(hidden_output,
                                     output_dim,
                                     expert_dims=[y0, x1],
                                     use_bias=False,
                                     master_dtype=master_dtype,
                                     slice_dtype=slice_dtype,
                                     name="expert1")

    # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
    # expert_output has shape [y0, x1, h, d, n]

    # alltoall
    expert_output = mtf.reshape(expert_output, mtf.Shape([y, x1, h0, d, n]))

    # combine results from inner level
    output_y = mtf.einsum([expert_output, combine_tensor_y], [x1, h0, t, n])

    # Reshape the combined tensor from inner level to now contain outer_batch_dim
    # a0 and group_dim g
    output = mtf.reshape(output_y, [x1, a0, g, c, n])

    # alltoall from expert_dim x to group_dim g1
    expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

    # combine results from outer level
    output_x = mtf.einsum([expert_output_x, combine_tensor_x], [a0, g1, s, n])

    # Reshape the combined tensor to now contain inner_batch_dim
    # b1 and the original sequence length
    output = mtf.reshape(output_x, [a0, b1, l, n])
    # Undo the size-1 outer batch dim added for 3-D inputs.
    if insert_outer_batch_dim:
        output = mtf.reshape(output, [b1, l, n])
    return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
示例#12
0
def transformer_moe_layer_v1(inputs,
                             output_dim,
                             hparams,
                             train,
                             variable_dtype,
                             layout=None,
                             mesh_shape=None,
                             nonpadding=None,
                             activation=mtf.relu,
                             num_microbatches=None,
                             token_embeddings=None,
                             context=None):
    """Local heterogeneous mixture of experts.

    See transformer_moe_layer_v1 in moe.py for a more detailed explanation of
    a generic MoE layer.

    The heterogeneous mask output by generate_heterogeneous_expert_masks has
    dimension [maximum hidden size, maximum # layers, # experts]. Its shape
    overwrites hparams.moe_num_layers and hparams.moe_hidden_size (hparams is
    mutated in place). At each expert layer, that layer's mask slice is applied
    to the activation, which has dims [expert width, # experts]. If
    hparams.moe_heterogeneous_mask_info is None, no mask is applied and the
    code is equivalent to the homogeneous case.

    The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
    of the representations of all positions in a batch of sequences.

    Each position of each sequence is sent to a small number of experts; how
    many depends on the gating function selected by hparams.moe_gating. The
    expert choices and the combination weights are determined by a learned
    gating function.

    This function returns a small auxiliary loss that should be added to the
    training loss of the model.  This loss helps to balance expert usage.
    Without the loss, it is very likely that a few experts will be trained and
    the rest will starve.

    Hyperparameters read by this function:
      moe_num_experts, moe_heterogeneous_mask_info, moe_hidden_size,
      moe_num_layers, moe_group_size, moe_word_embed_mode,
      moe_capacity_factor_train, moe_capacity_factor_eval,
      moe_min_expert_capacity, moe_gating, moe_dropout_rate, moe_loss_coef
      (plus whatever the selected gating function reads).

    Dimensions cheat sheet:
    B: batch dim(s)
    L: original sequence length
    M: input depth
    N: output depth
    G: number of groups
    S: group size
    E: number of experts
    C: expert capacity

    Args:
      inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]
      output_dim: a mtf.Dimension (for Transformer, this is input_dim)
      hparams: model hyperparameters
      train: a boolean
      variable_dtype: a mtf.VariableDType
      layout: optional - an input to mtf.convert_to_layout_rules
      mesh_shape: optional - an input to mtf.convert_to_shape
      nonpadding: an optional Tensor with shape [batch_dim(s), length_dim]
        and the same dtype as inputs, consisting of ones(nonpadding)
        and zeros(padding).
      activation: a function.
      num_microbatches: number of microbatches.
      token_embeddings: a mtf.Tensor with shape
        [batch_dim(s), length_dim, input_dim]. These are the word embeddings
        that correspond to the inputs. They are used by the router only when
        hparams.moe_word_embed_mode is not None, and are required in that case.
      context: a Context.

    Returns:
      outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
      loss: a mtf scalar

    Raises:
      ValueError: on unrecognized hparams.moe_gating, if
        hparams.moe_word_embed_mode is set but token_embeddings is None, or if
        the (possibly mask-derived) hparams.moe_num_layers is less than 1.
    """
    orig_inputs = inputs

    experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

    if hparams.moe_heterogeneous_mask_info is not None:
        tf.logging.info("moe_heterogeneous_mask_info: {}".format(
            hparams.moe_heterogeneous_mask_info))
        heterogeneous_mask = generate_heterogeneous_expert_masks(
            hparams.moe_heterogeneous_mask_info,
            hparams.moe_num_experts,
            experts_dim,
            mesh=inputs.mesh,
            expert_width=hparams.moe_hidden_size)
        # Overwrite depth and width with the mask's maximum dimensions.
        # NOTE: this mutates the caller's hparams.
        hparams.moe_num_layers = heterogeneous_mask.shape[1].size
        hparams.moe_hidden_size = heterogeneous_mask.shape[0].size
    # Without at least one expert layer there is no expert output to combine.
    if hparams.moe_num_layers < 1:
        raise ValueError("hparams.moe_num_layers must be >= 1, got %s" %
                         hparams.moe_num_layers)
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)

    # We "cheat" here and look at the mesh shape and layout. This is to ensure
    # that the number of groups is a multiple of the mesh dimension
    # over which those groups are split.
    batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                        orig_inputs.shape.dims[-1])
    # Hack: we assume that
    #   "outer_batch" == replication of experts
    #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
    #
    # We then require num_groups to be a multiple of mesh_dim_size.
    if orig_inputs.shape.dims[0].name == "outer_batch":
        outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
    else:
        outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
                                           orig_inputs.shape.dims[0])

    # Number of MoE inputs (total number of positions across
    # batch_and_length_dims) per replica.
    n = 1
    for d in batch_and_length_dims:
        n *= d.size

    n = n // outer_batch_dim.size

    mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                    orig_batch_dim)
    num_groups, group_size = moe._split_into_groups(  # pylint: disable=protected-access
        n, hparams.moe_group_size, mesh_dim_size)
    # TODO(barretzoph): implementation without pylint calls?

    group_size_dim = mtf.Dimension("group", group_size)
    num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

    moe_input_dims = [
        outer_batch_dim, num_groups_dim, group_size_dim, input_dim
    ]
    # OGSM Tensor
    inputs = mtf.reshape(inputs, moe_input_dims)

    # Token embeddings that can be optionally used in the router for determining
    # where to send tokens.
    if hparams.moe_word_embed_mode is not None:
        if token_embeddings is None:
            raise ValueError(
                "token_embeddings must be provided when "
                "hparams.moe_word_embed_mode=%s" % hparams.moe_word_embed_mode)
        token_embeddings = mtf.cast(
            mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)

    # Each sequence sends expert_capacity positions to each expert.
    # Capacity is capacity_factor * S / E, capped at S and floored at
    # hparams.moe_min_expert_capacity (the floor wins if it exceeds S).
    if train:
        capacity_factor = hparams.moe_capacity_factor_train
    else:
        capacity_factor = hparams.moe_capacity_factor_eval
    expert_capacity = min(
        group_size_dim.size,
        int((group_size_dim.size * capacity_factor) / experts_dim.size))
    expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
    tf.logging.info("expert_capacity: %d", expert_capacity)
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
    if nonpadding is not None:
        # Broadcast nonpadding to the full batch/length shape before regrouping.
        nonpadding = mtf.zeros(inputs.mesh,
                               batch_and_length_dims,
                               dtype=inputs.dtype) + nonpadding
        nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])

    # Arguments shared by every gating function.
    gating_kwargs = {
        "inputs": inputs,
        "outer_expert_dims": None,
        "experts_dim": experts_dim_unsplit,
        "expert_capacity_dim": expert_capacity_dim,
        "hparams": hparams,
        "train": train,
        "variable_dtype": variable_dtype,
        "importance": nonpadding,
        "num_microbatches": num_microbatches,
        "token_embeddings": token_embeddings,
    }
    # pylint: disable=protected-access
    if hparams.moe_gating == "top_2":
        gating_fn = moe._top_2_gating
    elif hparams.moe_gating == "top_n":
        gating_fn = moe._top_n_gating
    elif hparams.moe_gating == "switch":
        gating_fn = moe._switch_gating
    elif hparams.moe_gating == "ntlb":
        gating_fn = moe._ntlb_gating
    elif hparams.moe_gating == "switch_max":
        gating_fn = moe._switch_max_gating
    elif hparams.moe_gating == "expert_selection":
        gating_fn = moe._expert_selection_gating
        gating_kwargs["group_size_dim"] = group_size_dim
        gating_kwargs["name"] = "expert_selection_gating"
    else:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)
    # pylint: enable=protected-access
    # combine_tensor, dispatch_tensor: OGSEC Tensors
    # (G is generally split along mesh dim)
    dispatch_tensor, combine_tensor, loss = gating_fn(**gating_kwargs)

    expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                               mtf.Shape([
                                   outer_batch_dim, experts_dim_unsplit,
                                   num_groups_dim, expert_capacity_dim,
                                   input_dim
                               ]))

    # Extra reshape reduces communication cost for model-parallel versions.
    # For model-parallel versions, this reshape causes an mtf.slice and for non-
    # model-parallel versions, this has no effect.
    d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape([
            outer_batch_dim, experts_dim, batch_dim_unsplit,
            expert_capacity_dim, d_model_split_dim
        ]))

    # Split over batch -> split over experts
    expert_inputs = mtf.reshape(
        expert_inputs,
        mtf.Shape([
            outer_batch_dim, experts_dim, batch_dim_unsplit,
            expert_capacity_dim, input_dim
        ]))

    # Stack of hparams.moe_num_layers expert FFN layers; layers after the first
    # are pre-normed and residual.
    for layer_idx in range(hparams.moe_num_layers):
        with tf.variable_scope("expert_layer_{}".format(layer_idx)):
            res_h = 0.0
            if layer_idx > 0:
                res_h = expert_inputs
                expert_inputs = transformer.sublayer_rms_norm(
                    expert_inputs, None, context)

            # Now feed the expert inputs through the experts.
            h = mtf.layers.dense_product(
                expert_inputs,
                reduced_dims=expert_inputs.shape.dims[-1:],
                new_dims=[hidden_dim],
                expert_dims=[experts_dim],
                activation_functions=activation,
                use_bias=False,
                variable_dtype=variable_dtype,
                name="wi")

            # apply dropout
            if hparams.moe_dropout_rate != 0.0:
                h = mtf.dropout(h,
                                is_training=train,
                                keep_prob=1.0 - hparams.moe_dropout_rate)
            # only if heterogeneous
            if hparams.moe_heterogeneous_mask_info is not None:
                # Get mask for current layer by slicing heterogeneous mask
                heterogeneous_mask_slice = mtf.slice(heterogeneous_mask,
                                                     layer_idx, 1,
                                                     "num_expert_layers")

                # Get rid of the expert layers dimension.
                heterogeneous_mask_slice = mtf.reshape(
                    heterogeneous_mask_slice, [
                        heterogeneous_mask_slice.shape[0],
                        heterogeneous_mask_slice.shape[-1]
                    ])
                h *= mtf.cast(heterogeneous_mask_slice, h.dtype)
            expert_output = mtf.layers.dense(h,
                                             output_dim,
                                             expert_dims=[experts_dim],
                                             use_bias=False,
                                             reduced_dims=h.shape.dims[-1:],
                                             variable_dtype=variable_dtype,
                                             name="wo")

            if layer_idx < (hparams.moe_num_layers - 1):
                expert_output = transformer.sublayer_dropout(
                    expert_output, None, context)
            expert_output += res_h
            expert_inputs = expert_output

    # Extra reshape reduces communication cost for model-parallel versions.
    # For model-parallel versions, this reshape causes an mtf.slice and for non-
    # model-parallel versions, this has no effect.
    # The expert output carries output_dim, so its split dim must be sized by
    # output_dim. It keeps the "d_model_split" name for the layout rules and
    # equals d_model_split_dim whenever input and output sizes match.
    output_split_dim = mtf.Dimension("d_model_split", output_dim.size)
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim, experts_dim_unsplit, num_groups_dim,
            expert_capacity_dim, output_split_dim
        ]))

    # Split over experts -> split over batch
    expert_output = mtf.reshape(
        expert_output,
        mtf.Shape([
            outer_batch_dim,
            experts_dim_unsplit,
            num_groups_dim,
            expert_capacity_dim,
            output_dim,
        ]))
    moe_output_dims = moe_input_dims[:-1] + [output_dim]
    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
    return output, loss * hparams.moe_loss_coef
示例#13
0
def model_fn(features, labels, mode, params):
    """Estimator model_fn that builds, lowers and wraps a Mesh TensorFlow graph.

    Builds an mtf graph + mesh from `params`, imports the input features as
    fully-replicated mtf tensors, builds the model (`gpt2.model`, or
    `sample_autoregressive` for non-export prediction), lowers the mtf graph to
    TensorFlow and returns a `tpu_estimator.TPUEstimatorSpec` for `mode`.

    Args:
      features: the "inputs" tensor, or a dict whose "feature" entry holds it.
      labels: the "labels" tensor (or a dict with a "feature" entry), or None.
      mode: a `tf.estimator.ModeKeys` value; PREDICT, TRAIN and EVAL are handled.
      params: dict of hyperparameters. Keys read here include "mesh_shape",
        "layout", "use_tpu", "gpu_ids", "precision", "n_ctx", "causal",
        "n_embd", "n_vocab", "remove_partial_sequences", "export", "eos_id",
        "sampling_use_entmax", "tokens_per_mb_per_replica", "model",
        "auto_layout", "auto_layout_and_mesh_shape", "num_cores", "log_grads",
        "model_path", "steps_per_checkpoint" and "eval_task".
        NOTE(review): in PREDICT mode params["layout"] is overwritten in place
        before `add_mode_to_params` runs, so the caller's dict is mutated. Later
        writes ("remove_partial_sequences", "num_microbatches") go to whatever
        `add_mode_to_params` returns, which may be the same dict.

    Returns:
      A `tpu_estimator.TPUEstimatorSpec` for PREDICT, TRAIN or EVAL.

    Raises:
      Exception: in TRAIN/EVAL if params["model"] is not "GPT".
      AssertionError: if mode is not PREDICT, TRAIN or EVAL.
    """
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    if mode == tf.estimator.ModeKeys.PREDICT:
        # NOTE(review): mutates the caller's params dict in place.
        params["layout"] = remove_batch_from_layout(params["layout"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])
    
    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)
    else:
        # Non-TPU path: place mesh slices on the explicitly listed GPU devices.
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32, activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {"inputs": params["n_ctx"], "labels": params["n_ctx"]}

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    # Import each non-None feature as an int32 [batch, sequence] mtf tensor.
    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            # A dict-valued feature carries its tensor under the "feature" key.
            # Only features_dict is updated here; the `labels` argument is not.
            if type(features_dict[key]) == dict:
                features_dict[key] = features_dict[key]["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(
                mesh, x, feature_shape, name=key)

    # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    # The attention bias is only built for causal models; otherwise it is None.
    attn_bias = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same dimension in them breaks things
    # This dim is specifically for the weights
    # This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs, other_features=other_features, params=params, variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"], sampling_use_entmax=params['sampling_use_entmax'])

        else:
            # Export path: the first output of gpt2.model is returned as
            # "outputs" instead of sampled tokens; loss values are unused.
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh, variable_dtype=variable_dtype, context=None)

        # Anonymize so the tensors can be exported from the lowered graph.
        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {
            "inputs": inputs,
            "outputs": outputs}
        
        def scaffold_fn():
            # Local init also copies master variables into their slices.
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat(
                    [tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()],
                    axis=0,
                    name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL)

    # Gets number of microbatches per batch for serialized training
    # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed
    num_microbatches = int(mtf_transformer.utils.serialize_num_microbatches(batch_dim=batch_dim,
                                                                        sequence_length=sequence_length_dict,
                                                                        mesh_shape=mesh_shape,
                                                                        layout_rules=layout_rules,
                                                                        tokens_per_microbatch_per_replica=params["tokens_per_mb_per_replica"]))
    params["num_microbatches"] = num_microbatches  # Add num microbatches to params
    
    if num_microbatches > 1:
        # For serialize_training_step we need to modify the model to output results in a dict
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh, variable_dtype=variable_dtype)
                return {"logits": logits, "loss": loss, "loss_batch": loss_batch}
            else:
                raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")

        # Serialize the training step - Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh, variable_dtype=variable_dtype, context=None)
        else:
            raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype, inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                # L2 norm of each gradient tensor.
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                # NOTE(review): [:-2] presumably strips a ":0"-style suffix from the name — confirm.
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(tf_loss_batch):
                # exp(mean loss), streamed as a running mean over eval batches.
                loss = tf.reduce_mean(tf_loss_batch)
                # NOTE(review): presumably undoes accumulation of loss_batch
                # across microbatches — confirm this scaling is intended.
                loss /= params["num_microbatches"]
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                perp = _perplexity(tf_loss_batch)
                return {"mean_logits": mean_logits, "perplexity": perp}

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                # Scores only the positions whose label is not the EOS token.
                eos_token = params["eos_id"]
                answer_positions = tf.where(tf.math.not_equal(labels, eos_token))

                correct_answers = tf.gather_nd(tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(tf.cast(correct_answers, tf.float32))

                # I guess tf_loss_batch has z_loss and maybe other stuff added to it
                # so maybe this should be calculated separately in the future
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)

                return {"lambada_acc": accuracy, "lambada_log_ppl": log_perplexity}

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                # NOTE(review): passes the original `labels` argument. If labels
                # arrived as a dict, only features_dict["labels"] was unwrapped
                # above — confirm labels is a plain tensor here.
                eval_metrics = (_lambada_metric_fn, [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
def layer_prepostprocess_dropout(x, hparams):
    """Apply dropout to `x` with a noise shape of only its first and last dims.

    The dropout rate comes from `hparams.layer_prepostprocess_dropout`. The
    noise shape is built from `x`'s first and last dimensions only. Presumably
    that makes one mask shared across all intermediate dimensions (e.g. length).

    Args:
      x: an mtf.Tensor.
      hparams: an object exposing `layer_prepostprocess_dropout` (a rate).

    Returns:
      The result of `mtf.dropout` on `x`.
    """
    dims = x.shape.dims
    leading, trailing = dims[0], dims[-1]
    keep = 1.0 - hparams.layer_prepostprocess_dropout
    shared_noise = mtf.Shape([leading, trailing])
    return mtf.dropout(x, keep_prob=keep, noise_shape=shared_noise)
示例#15
0
    def decode(self,
               inputs,
               variable_dtype=mtf.VariableDType(tf.float32),
               beam_size=1,
               alpha=0.6,
               temperature=1.0,
               decode_length_multiplier=1.5,
               decode_length_constant=10):
        """Sampling or beam search.

        Encodes `inputs` once. It then samples autoregressively
        (beam_size == 1) or runs beam search (beam_size > 1) with the decoder,
        attending to the encoder output.

        TODO(noam): should we make the output length dimension different from
        the input length dimension?

        Args:
          inputs: a Tensor with shape [<batch_dims>, length_dim]. It must NOT
            already contain a beam dimension: for beam search one is inserted
            before length_dim here. Nonzero entries are treated as real tokens
            (presumably 0 is padding) when building the encoder sequence id
            and the input length.
          variable_dtype: a mtf.VariableDType
          beam_size: an integer >= 1
          alpha: a floating point value (length bonus for beam search)
          temperature: a value between 0 and 1 (must be 0 if beam_size > 1)
          decode_length_multiplier: a float; beam search decodes up to
            max_input_length * decode_length_multiplier + decode_length_constant
            tokens.
          decode_length_constant: a float (see decode_length_multiplier)

        Returns:
          If beam_size == 1: the result of `self.decoder.sample_autoregressive`
          started from all-zero sequences shaped like `inputs`.
          Otherwise: the result of `self.decoder.beam_search` started from
          all-zero sequences shaped [<batch_dims>, beam_dim, length_dim].
          NOTE(review): exact output shapes are determined by the decoder
          methods; confirm against them.

        Raises:
          ValueError: if beam_size > 1 and temperature != 0.
        """
        shared_params = self._shared_params(inputs.mesh, variable_dtype)
        # 1 for every nonzero token, 0 elsewhere.
        encoder_sequence_id = mtf.minimum(inputs, 1)
        encoder_output, encoder_loss = self.encoder.call_simple(
            inputs=inputs,
            targets=None,
            compute_loss=False,
            mode=tf.estimator.ModeKeys.PREDICT,
            variable_dtype=variable_dtype,
            sequence_id=encoder_sequence_id,
            shared_params=shared_params)
        del encoder_loss
        # The decoder attends over the encoder output along memory_length.
        encoder_output = mtf.layers.rename_length_to_memory_length(
            encoder_output)
        encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
            encoder_sequence_id)
        if beam_size == 1:
            ids_shape = inputs.shape
            partial_sequences = mtf.zeros(inputs.mesh,
                                          ids_shape,
                                          dtype=tf.int32)
            return self.decoder.sample_autoregressive(
                partial_sequences,
                temperature=temperature,
                variable_dtype=variable_dtype,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                shared_params=shared_params,
                has_partial_sequences=False)
        else:
            if temperature != 0:
                raise ValueError(
                    "don't know how to beam search with nonzero temperature")
            # beam search: insert a beam dimension just before length_dim.
            beam_dim = mtf.Dimension("beam", beam_size)
            batch_dims = inputs.shape[:-1]
            length_dim = inputs.shape[-1]
            ids_shape = mtf.Shape(batch_dims + [beam_dim, length_dim])
            partial_sequences = mtf.zeros(inputs.mesh,
                                          ids_shape,
                                          dtype=tf.int32)
            # Number of nonzero tokens per example.
            input_length = mtf.reduce_sum(mtf.to_float(
                mtf.cast(inputs, tf.bool)),
                                          reduced_dim=length_dim)
            max_input_length = mtf.reduce_max(input_length)
            decode_length = mtf.cast(
                max_input_length * decode_length_multiplier +
                decode_length_constant, tf.int32)
            return self.decoder.beam_search(
                partial_sequences,
                decode_length,
                variable_dtype=variable_dtype,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                alpha=alpha,
                shared_params=shared_params)
示例#16
0
 def layer_prepostprocess_dropout(x):
     """Dropout on `x` using a [<batch_dims>, model_dim] noise shape.

     `hparams` and `self` are taken from the enclosing scope.
     """
     keep = 1.0 - hparams.layer_prepostprocess_dropout
     mask_dims = self.batch_dims + [self.model_dim]
     return mtf.dropout(x, keep_prob=keep, noise_shape=mtf.Shape(mask_dims))
示例#17
0
 def testBroadcastOperation(self):
     """Broadcasting self.x to (b, c, a) leaves every dim splittable."""
     target = mtf.Shape([self.b_dim, self.c_dim, self.a_dim])
     op = mtf.BroadcastOperation(self.x, target)
     expected_splittable = frozenset(["a", "b", "c"])
     expected_unsplittable = frozenset()
     self.assertEqual(op.splittable_dims, expected_splittable)
     self.assertEqual(op.unsplittable_dims, expected_unsplittable)
示例#18
0
    def _layer_stack(self,
                     x,
                     layers,
                     encoder_output=None,
                     self_attention_mask=None,
                     encdec_attention_mask=None,
                     losses=None,
                     step_num=None,
                     encdec_tensors=None,
                     self_attention_k=None,
                     self_attention_v=None):
        """Encoder or decoder stack.

        Each layer is applied as a pre-normalized residual branch:
        x += dropout(layer(normalize(x))). A final normalize + dropout is
        applied after the stack. Incremental mode is enabled when `step_num`
        is given; it skips dropout and threads self-attention k/v caches.

        Args:
          x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
          layers: a list of strings. "att" is self-attention, "enc_att" is
            encoder-decoder attention, and anything else is passed to
            self._feedforward_layer as its layer type.
          encoder_output: an optional mtf.Tensor with shape
            [<batch_dims>, encoder_length_dim, model_dim]
          self_attention_mask: an optional mtf.Tensor with shape
            [batch, length_dim, memory_length_dim] containing values 0 or -inf.
          encdec_attention_mask: an optional mtf.Tensor with shape
            [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
          losses: a list to be appended-to
          step_num: an optional mtf integer Scalar (used in incremental mode)
          encdec_tensors: an optional list of num_layers tuples, each of the form
            (q_var, o_var, k, v), (used in incremental mode). Indexed by the
            layer's position in `layers`.
          self_attention_k: an optional list of Tensors, one per "att" layer,
            each with shape [batch, heads, memory_length, kv_channels]
            (incremental mode)
          self_attention_v: an optional list of Tensors, one per "att" layer,
            each with shape [batch, heads, memory_length, kv_channels]
            (incremental mode)
        Returns:
          Non-incremental mode: a mtf.Tensor with shape
            [<batch_dims>, length_dim, model_dim].
          Incremental mode: a tuple (x, new_self_attention_k,
            new_self_attention_v), where the lists hold the updated k/v for
            each "att" layer.
        Raises:
          ValueError: if hparams make no sense. Not raised directly in this
            body; presumably from self._feedforward_layer — confirm.
        """
        hparams = self._hparams
        is_incremental = (step_num is not None)

        def layer_prepostprocess_dropout(x):
            # Dropout is disabled during incremental decoding.
            if is_incremental:
                return x
            return mtf.dropout(
                x,
                keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
                noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

        # One norm scale per layer plus one for the final normalization.
        num_layers = len(layers)
        num_layer_norms = num_layers + 1
        layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
        layer_norm_combined_var = mtf.get_variable(
            x.mesh,
            "layer_norm_scale",
            mtf.Shape([layer_norms_dim, self.model_dim]),
            initializer=tf.ones_initializer(),
            activation_dtype=x.dtype)
        layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)

        def normalize(x):
            # RMS-style normalization (no mean subtraction, no bias). Each call
            # consumes the next scale, so call order must match layer order.
            scale = layer_norm_vars.pop(0)
            variance = mtf.reduce_mean(mtf.square(x),
                                       reduced_dim=self.model_dim)
            return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

        if is_incremental:
            new_self_attention_k = []
            new_self_attention_v = []

        for lnum, layer_type in enumerate(layers):
            with tf.variable_scope("%s_%d" % (layer_type, lnum)):
                if layer_type == "att":
                    # Self attention layer
                    if is_incremental:
                        # The k/v caches are indexed by self-attention layer
                        # count, not by lnum.
                        self_att_num = len(new_self_attention_k)
                        y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
                            normalize(x),
                            prev_k=self_attention_k[self_att_num],
                            prev_v=self_attention_v[self_att_num],
                            step_num=step_num,
                            master_dtype=self.master_dtype,
                            slice_dtype=self.slice_dtype,
                            name="att")
                        new_self_attention_k.append(new_k)
                        new_self_attention_v.append(new_v)
                        x += y
                    else:
                        x += layer_prepostprocess_dropout(
                            mtf.layers.multihead_attention(
                                normalize(x),
                                None,
                                self_attention_mask,
                                self.kv_dim,
                                self.heads_dim,
                                dropout=hparams.attention_dropout,
                                dropout_broadcast_dims=[self.length_dim],
                                master_dtype=self.master_dtype,
                                slice_dtype=self.slice_dtype,
                                name="att"))
                elif layer_type == "enc_att":
                    # Encoder-Decoder attention layer
                    if is_incremental:
                        # Encoder-Decoder attention layer
                        q_var, o_var, k, v = encdec_tensors[lnum]
                        x += mtf.layers.multihead_encdec_attention_incremental(
                            normalize(x),
                            q_var,
                            o_var,
                            k,
                            v,
                            encdec_attention_mask,
                            name="enc_att")
                    else:
                        x += layer_prepostprocess_dropout(
                            mtf.layers.multihead_attention(
                                normalize(x),
                                encoder_output,
                                encdec_attention_mask,
                                self.kv_dim,
                                self.heads_dim,
                                dropout=hparams.attention_dropout,
                                dropout_broadcast_dims=[self.length_dim],
                                master_dtype=self.master_dtype,
                                slice_dtype=self.slice_dtype,
                                name="enc_att"))
                else:
                    if is_incremental:
                        # insert length dimension.
                        # Presumably _feedforward_layer expects a length dim,
                        # which incremental-mode x lacks — confirm.
                        x_shape = x.shape
                        shape_with_length = mtf.Shape(
                            x_shape.dims[:-1] + [mtf.Dimension("length", 1)] +
                            x_shape.dims[-1:])
                        x = mtf.reshape(x, shape_with_length)
                    # ffn layer
                    x += layer_prepostprocess_dropout(
                        self._feedforward_layer(normalize(x),
                                                layer_type,
                                                losses=losses))
                    if is_incremental:
                        # remove length dimension
                        x = mtf.reshape(x, x_shape)

        x = layer_prepostprocess_dropout(normalize(x))
        # Every layer-norm scale must have been consumed exactly once.
        assert not layer_norm_vars
        if is_incremental:
            return x, new_self_attention_k, new_self_attention_v
        else:
            return x
示例#19
0
class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):
    @parameterized.parameters(
        (mtf.Dimension("x", 5), ),
        (("x", 5), ),
    )
    def testConvertToDimension(self, inputs):
        dimension = mtf.convert_to_dimension(inputs)
        self.assertEqual(dimension.name, "x")
        self.assertEqual(dimension.size, 5)

    def testConvertToDimensionGenericInputs(self):
        dimension = mtf.convert_to_dimension(None)
        self.assertEqual(dimension, None)
        with self.assertRaises(TypeError):
            mtf.convert_to_dimension(5)

    @parameterized.parameters(
        (mtf.Shape([mtf.Dimension("x", 4),
                    mtf.Dimension("y", 8)]), ),
        ("x:4;y:8", ),
        ("x:4.y:8", ),
        ("x:4 y:8", ),
        ("x:4,y:8", ),
    )
    def testConvertToShape(self, inputs):
        shape = mtf.convert_to_shape(inputs)
        self.assertEqual(
            shape, mtf.Shape([mtf.Dimension("x", 4),
                              mtf.Dimension("y", 8)]))

    def testConvertToShapeGenericInputs(self):
        shape = mtf.convert_to_shape([])
        self.assertEqual(shape.dims, [])
        shape = mtf.convert_to_shape(None)
        self.assertEqual(shape, None)
        with self.assertRaises(ValueError):
            mtf.convert_to_shape("x;4")

    @parameterized.parameters(
        (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]), ),
        ("d_ff:model;heads:model", ),
        ("d_ff:model.heads:model", ),
        ("d_ff:model heads:model", ),
        ("d_ff:model,heads:model", ),
        ([("d_ff", "model"), ("heads", "model")], ),
    )
    def testConvertToLayoutRules(self, inputs):
        layout_rules = mtf.convert_to_layout_rules(inputs)
        self.assertEqual(
            layout_rules._pairs,
            mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)

    def testConvertToLayoutRulesGenericInputs(self):
        with self.assertRaises(ValueError):
            mtf.convert_to_layout_rules("d_ff;heads")

    def testTensorLayout(self):
        tensor_layout = mtf.TensorLayout([0, 2, 1])
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(0), ())
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(1), (0, ))
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(2), (0, 2))
        tensor_layout = mtf.TensorLayout([None, 0])
        self.assertFalse(tensor_layout.is_fully_replicated)
        tensor_layout = mtf.TensorLayout([None, None, None])
        self.assertTrue(tensor_layout.is_fully_replicated)

    def testGraph(self):
        graph = mtf.Graph()
        self.assertEmpty(graph.operations)
        self.assertEmpty(graph.trainable_variables)
        self.assertEmpty(graph.all_variables)
        mesh = mtf.Mesh(graph, "mesh_test")
        _ = mtf.import_tf_tensor(mesh,
                                 tf_tensor=tf.constant(0.),
                                 shape=mtf.Shape([]))
        self.assertLen(graph.operations, 1)
        self.assertEmpty(graph.trainable_variables)
        self.assertEmpty(graph.all_variables)
        _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
        self.assertLen(graph.operations, 2)
        self.assertLen(graph.trainable_variables, 1)
        self.assertLen(graph.all_variables, 1)
        _ = mtf.get_variable(mesh,
                             "variable_1",
                             mtf.Shape([]),
                             trainable=False)
        self.assertLen(graph.operations, 3)
        self.assertLen(graph.trainable_variables, 1)
        self.assertLen(graph.all_variables, 2)

    def testGraphNames(self):
        # Standard Usage.
        graph = mtf.Graph()
        self.assertEqual(graph.unique_name("a"), "a")
        self.assertEqual(graph.unique_name("a"), "a_1")
        self.assertEqual(graph.unique_name("a"), "a_2")

        # Edge cases, the user may choose the name "a_1".
        graph = mtf.Graph()
        self.assertEqual(graph.unique_name("a"), "a")
        self.assertEqual(graph.unique_name("a"), "a_1")
        self.assertEqual(graph.unique_name("a_1"), "a_1_1")

        graph = mtf.Graph()
        self.assertEqual(graph.unique_name("a"), "a")
        self.assertEqual(graph.unique_name("a_1"), "a_1")
        self.assertEqual(graph.unique_name("a"), "a_2")

        # Case insensitive.
        graph = mtf.Graph()
        self.assertEqual(graph.unique_name("a"), "a")
        self.assertEqual(graph.unique_name("A"), "A_1")

    @test_util.run_in_graph_and_eager_modes()
    def testLowering(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        inputs = tf.constant(0.)
        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          tf_tensor=inputs,
                                          shape=mtf.Shape([]))
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        outputs = lowering.export_to_tf_tensor(mtf_inputs)
        inputs_value, outputs_value = self.evaluate([inputs, outputs])
        self.assertEqual(inputs_value, outputs_value)

        # Check that methods run without error.
        _ = lowering.copy_masters_to_slices()
        _ = lowering.copy_slices_to_masters()

    def testMesh(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        self.assertEqual(mesh.graph, graph)

    def testMeshImpl(self):
        shape = mtf.Shape(
            [mtf.Dimension("batch", 4),
             mtf.Dimension("model", 8)])
        layout_rules = mtf.LayoutRules([("batch", "batch"), ("d_ff", "model"),
                                        ("heads", "model")])
        mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
        self.assertEqual(mesh_impl.shape, shape)
        self.assertLen(shape, mesh_impl.ndims)
        self.assertEqual(mesh_impl.layout_rules, layout_rules)
        self.assertEqual(mesh_impl.size, shape.size)
        self.assertTrue(mesh_impl.supports_control_dependencies)

        batch = mtf.Dimension("batch", 128)
        length = mtf.Dimension("length", 500)
        d_ff = mtf.Dimension("d_ff", 2048)
        heads = mtf.Dimension("heads", 8)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(batch), 0)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(d_ff), 1)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(heads), 1)
        self.assertEqual(
            mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
            mtf.TensorLayout([0, None, 1]))

    @parameterized.parameters(
        {"pool_fn": np.mean, "pool_fn_mtf": mtf.reduce_mean},
        {"pool_fn": np.max, "pool_fn_mtf": mtf.reduce_max},
        {"pool_fn": np.min, "pool_fn_mtf": mtf.reduce_min},
    )
    def testPoolTensor1d(self, pool_fn, pool_fn_mtf):
        """pool_tensor_1d reduces non-overlapping windows along one dim."""
        converter = mtf_test_utils.NumpyConverter()
        window = 2
        data = np.random.randn(2, 3, 4, 5)
        # Reference: reduce every window of `window` items along axis 2.
        expected = np.stack(
            [pool_fn(data[:, :, lo:lo + window, :], axis=2)
             for lo in range(0, data.shape[2], window)],
            axis=2)

        data_mtf = converter.convert_np_array_to_mtf_tensor(
            data, dtype=tf.float32)
        pooled = mtf.pool_tensor_1d(data_mtf,
                                    pool_dim=data_mtf.shape.dims[2],
                                    reduce_fn=pool_fn_mtf,
                                    pool_size=window)
        self.assertAllClose(expected,
                            converter.convert_mtf_tensor_to_np_array(pooled))

    @parameterized.parameters({"pool_size": 2}, {"pool_size": 3})
    def testStrideTensor1d(self, pool_size):
        """stride_tensor_1d keeps every `pool_size`-th entry along one dim."""
        converter = mtf_test_utils.NumpyConverter()
        data = np.random.randint(0, 100, size=[2, 3, 6, 5])
        data_mtf = converter.convert_np_array_to_mtf_tensor(data)
        # Reference: take indices 0, pool_size, 2 * pool_size, ... on axis 2.
        expected = data[:, :, ::pool_size, :]
        strided = mtf.stride_tensor_1d(data_mtf,
                                       pool_dim=data_mtf.shape.dims[2],
                                       pool_size=pool_size)
        self.assertAllEqual(expected,
                            converter.convert_mtf_tensor_to_np_array(strided))

    def testReduceFirst(self):
        """reduce_first keeps only index 0 of the reduced dimension."""
        converter = mtf_test_utils.NumpyConverter()
        data = np.random.randint(0, 100, size=[2, 3, 6, 5])
        data_mtf = converter.convert_np_array_to_mtf_tensor(data)
        first = mtf.reduce_first(data_mtf, reduced_dim=data_mtf.shape.dims[2])
        self.assertAllEqual(data[:, :, 0, :],
                            converter.convert_mtf_tensor_to_np_array(first))
示例#20
0
    def _sample(self, features, mesh):
        """Builds the autoregressive decoding graph (greedy/sampling or beam).

        For an "encdec" transformer, features["inputs"] is encoded once. The
        encoder-decoder attention keys/values are precomputed for every
        "enc_att" layer in hparams.decoder_layers. For a "decoder"
        transformer, features["inputs"] (or, if absent, features["targets"])
        is imported as partial targets that the greedy decode is forced to
        begin with.

        Args:
          features: a dict of tf.Tensors; reads "inputs" and/or "targets".
          mesh: the mtf.Mesh on which to build the decoding graph.

        Returns:
          If hparams.beam_size == 1, the result of
          mtf.beam_search.greedy_decode. Otherwise, the beam at index 0 of the
          beam dimension from mtf.beam_search.beam_search.

        Raises:
          ValueError: if hparams.transformer_type is neither "encdec" nor
            "decoder".
        """
        hparams = self._hparams
        (inputs_embedding_var, targets_embedding_var, softmax_var,
         positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
        if hparams.transformer_type == "encdec":
            inputs = features["inputs"]
            # Drop trailing dimensions until inputs is [batch, length].
            while len(inputs.shape.as_list()) > 2:
                inputs = tf.squeeze(inputs, axis=2)
            actual_batch_size = tf.shape(inputs)[0]
            actual_length = tf.shape(inputs)[1]
            # Pad up to [hparams.batch_size, hparams.max_length] before
            # importing onto the mesh.
            inputs = tf.pad(inputs,
                            [[0, hparams.batch_size - actual_batch_size],
                             [0, hparams.max_length - actual_length]])
            inputs = self._import_to_batch_by_length(inputs, "inputs", mesh,
                                                     hparams)
            x = (mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim) +
                 mtf.reshape(positional_embedding_var,
                             mtf.Shape([self.length_dim, self.model_dim])))
            encoder_attention_mask = (mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
            with tf.variable_scope("encoder"):
                x = self._layer_stack(
                    x,
                    hparams.encoder_layers,
                    self_attention_mask=encoder_attention_mask)
            encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                                  self.memory_length_dim.name)
            # One entry per decoder layer: (q_var, o_var, k, v) for "enc_att"
            # layers, None for all other layer types.
            encdec_tensors = []
            for layer_num, layer_type in enumerate(hparams.decoder_layers):
                if layer_type == "enc_att":
                    with tf.variable_scope("decoder/enc_att_%d/enc_att" %
                                           layer_num):
                        q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
                            mesh, self.heads_dim, self.model_dim, self.kv_dim,
                            self.master_dtype, self.slice_dtype,
                            self.activation_dtype)
                        k = mtf.einsum([encoder_output, k_var],
                                       mtf.Shape(self.batch_dims + [
                                           self.heads_dim,
                                           self.memory_length_dim, self.kv_dim
                                       ]))
                        v = mtf.einsum([encoder_output, v_var],
                                       mtf.Shape(self.batch_dims + [
                                           self.heads_dim,
                                           self.memory_length_dim, self.kv_dim
                                       ]))
                    encdec_tensors.append((q_var, o_var, k, v))
                else:
                    encdec_tensors.append(None)
            partial_targets = None
        elif hparams.transformer_type == "decoder":
            encdec_tensors = None
            encoder_output = None
            encoder_attention_mask = None
            # Prepare partial targets.
            # In either features["inputs"] or features["targets"].
            # We force the outputs to begin with these sequences.
            partial_targets = features.get("inputs", None)
            if partial_targets is None:
                partial_targets = features.get("targets", None)
            if partial_targets is not None:
                partial_targets = common_layers.expand_squeeze_to_nd(
                    partial_targets, 2)
                partial_targets = tf.to_int32(partial_targets)
                partial_targets_batch = tf.shape(partial_targets)[0]
                partial_targets_length = tf.shape(partial_targets)[1]
                partial_targets = tf.pad(
                    partial_targets,
                    [[0, hparams.batch_size - partial_targets_batch],
                     [0, hparams.max_length - partial_targets_length]])
                partial_targets = self._import_to_batch_by_length(
                    partial_targets, "partial_targets", mesh, hparams)
        else:
            raise ValueError("hparams.transformer_type = %s not yet supported" %
                             hparams.transformer_type)

        if hparams.beam_size == 1:
            ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
            kv_shape = mtf.Shape(
                self.batch_dims +
                [self.heads_dim, self.memory_length_dim, self.kv_dim])
        else:
            beam_dim = mtf.Dimension("beam", hparams.beam_size)
            ids_shape = mtf.Shape(self.batch_dims +
                                  [beam_dim, self.length_dim])
            kv_shape = mtf.Shape(self.batch_dims + [
                beam_dim, self.heads_dim, self.memory_length_dim, self.kv_dim
            ])

        initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
        num_self_att = len([l for l in hparams.decoder_layers if l == "att"])
        # The state list holds all self-attention keys followed by all
        # self-attention values (one of each per "att" layer).
        initial_kv_states = (
            [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] *
            (2 * num_self_att))

        def logits_fn(step_num, ids, states):
            """Produce logits for this step, and new states."""
            self_attention_k = states[:num_self_att]
            self_attention_v = states[num_self_att:]
            # Embed the token emitted at the previous step.
            ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
            x = (mtf.gather(targets_embedding_var, ids_this_step,
                            self.targets_vocab_dim) +
                 mtf.gather(positional_embedding_var, step_num,
                            self.max_length_dim))
            with tf.variable_scope("decoder"):
                x, new_self_attention_k, new_self_attention_v = (
                    self._layer_stack(
                        x,
                        hparams.decoder_layers,
                        encdec_attention_mask=encoder_attention_mask,
                        step_num=step_num,
                        encdec_tensors=encdec_tensors,
                        self_attention_k=self_attention_k,
                        self_attention_v=self_attention_v))
            logits = mtf.matmul(x, softmax_var)
            return logits, new_self_attention_k + new_self_attention_v

        if hparams.beam_size == 1:
            temperature = (0.0 if hparams.sampling_method == "argmax" else
                           hparams.sampling_temp)
            return mtf.beam_search.greedy_decode(
                logits_fn,
                initial_ids,
                temperature=temperature,
                initial_states=initial_kv_states,
                forced_ids=partial_targets,
                use_tpu=hparams.use_tpu)
        else:
            # NOTE(review): partial_targets (decoder-only models) are not
            # passed to beam_search, so outputs are not forced to begin with
            # them on this path — confirm this is intended.
            if hparams.transformer_type == "encdec":
                # Decode length scales with the longest non-padding input.
                input_length = mtf.reduce_sum(mtf.to_float(
                    mtf.cast(inputs, tf.bool)),
                                              reduced_dim=self.length_dim)
                max_input_length = mtf.reduce_max(input_length)
                decode_length = mtf.cast(
                    max_input_length * hparams.decode_length_multiplier +
                    hparams.decode_length_constant, tf.int32)
            else:
                decode_length = None
            beams, unused_scores = mtf.beam_search.beam_search(
                logits_fn,
                initial_ids,
                hparams.alpha,
                states=initial_kv_states,
                decode_length=decode_length,
                use_tpu=hparams.use_tpu,
                dtype=self.activation_dtype)
            return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32),
                              beam_dim)
示例#21
0
 def _import_to_batch_by_length(self, x, name, mesh):
     """Imports tf.Tensor `x` onto `mesh` with shape [<batch_dims>, length].

     `x` is first reshaped to the integer sizes of that shape and then
     imported fully replicated under the given `name`.
     """
     target = mtf.Shape(self.batch_dims + [self.length_dim])
     reshaped = tf.reshape(x, target.to_integer_list)
     return mtf.import_fully_replicated(mesh, reshaped, target, name=name)
示例#22
0
    def _decoder_layer_stack_incremental(self,
                                         x,
                                         step_num,
                                         encdec_tensors,
                                         self_attention_k,
                                         self_attention_v,
                                         encdec_attention_mask=None):
        """Decoder layer stack during inference, one position at a time.

        The self-attention keys and values for previous positions have already
        been computed. Besides the decoder output, this produces the updated
        self-attention keys and values.

        If there is an encoder, encdec_tensors supplies the keys and values
        for encoder-decoder attention and the weight matrices q_var and o_var.

        Each of the hparams.num_decoder_layers layers applies, in order:
        self-attention, optional encoder-decoder attention, then a
        feed-forward layer. Each sub-layer uses a pre-normalized input and a
        residual add. A final normalization is applied after the stack.

        Args:
          x: a mtf.Tensor with shape [<batch_dims>, model_dim].
          step_num: an mtf integer Scalar.
          encdec_tensors: None, or a list of num_layers tuples
            (q_var, o_var, k, v) indexed by layer number.
          self_attention_k: a list of num_layers Tensors each with shape
            [batch, heads, memory_length, kv_channels].
          self_attention_v: a list of num_layers Tensors each with shape
            [batch, heads, memory_length, kv_channels].
          encdec_attention_mask: an optional mtf.Tensor with shape
            [batch, length_dim, encoder_length_dim] containing values 0 or
            -inf.

        Returns:
          y: a mtf.Tensor with shape [<batch_dims>, model_dim].
          new_self_attention_k: a list of num_layers mtf.Tensors, with the
            same shapes as the elements of self_attention_k.
          new_self_attention_v: a list of num_layers mtf.Tensors, with the
            same shapes as the elements of self_attention_v.
        """
        hparams = self._hparams
        num_layers = hparams.num_decoder_layers
        # One norm scale per sub-layer (2 or 3 per layer), plus one for the
        # final normalization after the stack.
        num_layer_norms = num_layers * (2 if encdec_tensors is None else 3) + 1
        layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
        # All norm scales live in one variable and are unstacked into a list
        # that normalize() consumes in call order.
        layer_norm_combined_var = mtf.get_variable(
            x.mesh,
            "layer_norm_scale",
            mtf.Shape([layer_norms_dim, self.model_dim]),
            initializer=tf.ones_initializer(),
            activation_dtype=x.dtype)
        layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)

        def normalize(x):
            # Scale by rsqrt(mean(x^2) + eps) over model_dim (no mean
            # subtraction), using the next unused scale from the list. The
            # order of normalize() calls therefore determines which scale is
            # used.
            scale = layer_norm_vars.pop(0)
            variance = mtf.reduce_mean(mtf.square(x),
                                       reduced_dim=self.model_dim)
            return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

        new_self_attention_k = []
        new_self_attention_v = []
        for layer in xrange(num_layers):
            with tf.variable_scope("layer_%d" % layer):
                # Self attention layer
                y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
                    normalize(x),
                    prev_k=self_attention_k[layer],
                    prev_v=self_attention_v[layer],
                    step_num=step_num,
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="att")
                new_self_attention_k.append(new_k)
                new_self_attention_v.append(new_v)
                x += y
                if encdec_tensors is not None:
                    # Encoder-Decoder attention layer
                    q_var, o_var, k, v = encdec_tensors[layer]
                    x += mtf.layers.multihead_encdec_attention_incremental(
                        normalize(x),
                        q_var,
                        o_var,
                        k,
                        v,
                        encdec_attention_mask,
                        name="enc_att")
                # ffn layer
                x += self._feedforward_layer(normalize(x), layer)
        x = normalize(x)
        # Every allocated norm scale must have been consumed exactly once.
        assert not layer_norm_vars
        return x, new_self_attention_k, new_self_attention_v
def gradient_based_subword_tokenization(x,
                                        length_dim,
                                        max_subword_length=4,
                                        downsample=None,
                                        use_offsets=False,
                                        consider_chars_as_blocks=False,
                                        use_block_pos_embedding=False,
                                        share_block_kernel=False,
                                        memory_embeddings=0,
                                        context=None,
                                        block_mixing_mode=None,
                                        activation="softmax",
                                        downsample_function="mean"):
    """Implements GBSWT (gradient-based subword tokenization) from Charformer.

    For each candidate block width (2, 3, 4 and optionally 1), and optionally
    each offset within that width, the following steps are applied:
      1. Mean-pool the sequence into blocks.
      2. Score each block with a dense projection.
      3. Repeat the block representations and scores back to the original
         length.
    The candidates are then weighted by their scores (softmax by default) and
    summed per position. The result is optionally downsampled.

    Args:
      x: a Tensor containing length_dim; its last dimension is used as the
        model dimension.
      length_dim: a Dimension, the sequence dimension of x.
      max_subword_length: integer, currently unused.
      downsample: an optional integer. If set, the output is padded to a
        multiple of it and mean-pooled along length_dim by this factor.
      use_offsets: boolean. If True, also enumerate every shift
        1..width-1 of each block width.
      consider_chars_as_blocks: boolean. If True, also use blocks of width 1.
      use_block_pos_embedding: boolean. If True, add sinusoidal within-block
        position embeddings before pooling. Requires `context`.
      share_block_kernel: boolean. If True, all block widths share one
        scoring kernel. Requires `context`.
      memory_embeddings: integer, currently unused.
      context: Context; required by use_block_pos_embedding and
        share_block_kernel.
      block_mixing_mode: Str for block mixing. "score_attention" lets the
        block scores attend to each other before normalization.
      activation: Str for block ranking: "softmax" or "tanh". Any other value
        leaves the scores unnormalized.
      downsample_function: Str, only "mean" is supported for now.

    Returns:
      a Tensor with the same dimensions as x. If `downsample` is set,
      length_dim is reduced by that factor (rounded up).

    Raises:
      ValueError: if `context` is missing but required, or if `downsample`
        is set and `downsample_function` is not supported.
    """
    # These options are accepted but not used yet.
    del max_subword_length
    del memory_embeddings
    if context is None and (use_block_pos_embedding or share_block_kernel):
        raise ValueError("context is required when use_block_pos_embedding "
                         "or share_block_kernel is set.")
    all_blocks = []
    all_scores = []
    tf.logging.info("GSW block layer")
    length_name = length_dim.name

    def _tile(x, n, tile_dim):
        # Concatenate n copies of x along tile_dim (like np.tile).
        return mtf.concat([x] * n, tile_dim.name)

    def _repeat(x, n, repeat_dim):
        # Repeat every slice of x along repeat_dim n times (like np.repeat).
        # A temporary dim is inserted directly AFTER repeat_dim, tiled to
        # size n, and folded back into repeat_dim. This assumes mtf.reshape,
        # like tf.reshape, keeps row-major element order over the dims as
        # listed. Under that order the temp dim must sit right after
        # repeat_dim. If it were appended after trailing dims (e.g. the model
        # dim), the fold would interleave those dims instead of repeating
        # along repeat_dim.
        if repeat_dim.name not in x.shape.dimension_names:
            raise ValueError("%s has no dimension named %s" %
                             (x, repeat_dim.name))
        tmp_dim = mtf.Dimension("tmp", 1)
        expand_dims = []
        for dim in x.shape.dims:
            expand_dims.append(dim)
            if dim.name == repeat_dim.name:
                expand_dims.append(tmp_dim)
        x = mtf.reshape(x, mtf.Shape(expand_dims))
        x = _tile(x, n, tmp_dim)
        output_dims = []
        for dim in x.shape.dims:
            if dim.name == tmp_dim.name:
                continue
            if dim.name == repeat_dim.name:
                dim = mtf.Dimension(dim.name, dim.size * n)
            output_dims.append(dim)
        return mtf.reshape(x, mtf.Shape(output_dims))

    def _combined_dim(dims):
        return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)

    # compute all subword blocks
    # TODO(yitay): handle offsets to get all blocks
    if activation == "sigtanh":
        # "sigtanh" produces two scores per block.
        tmp_dim = mtf.Dimension("block_score", 2)
    else:
        tmp_dim = mtf.Dimension("block_score", 1)

    model_dim = x.shape[-1]
    subword_blocks_width = [2, 3, 4]

    if consider_chars_as_blocks:
        subword_blocks_width += [1]

    if share_block_kernel:
        block_kernel_shape = mtf.Shape([model_dim, tmp_dim])
        block_kernel = mtf.get_variable(x.mesh,
                                        "block_kernel",
                                        block_kernel_shape,
                                        initializer=None,
                                        dtype=context.variable_dtype)
    else:
        block_kernel = None

    for subword_len in subword_blocks_width:
        if use_block_pos_embedding:
            # This is turned off by default. It is meant to support cases like
            # parameterized pooling or other features.
            block_len_dim = mtf.Dimension(length_name, subword_len)
            # TODO(vqtran): Consider other positional embeddings.
            base_block_pos_emb = sinusoid_positional_embedding_weights(
                context.mesh, block_len_dim, x.shape[-1],
                context.variable_dtype.activation_dtype)
            # Tile (not element-repeat) the within-block positions
            # 0..subword_len-1, so every block sees the same position pattern.
            # The tiled length equals the padded sequence length below.
            base_block_pos_emb = _tile(
                base_block_pos_emb,
                int(math.ceil(length_dim.size / float(subword_len))),
                block_len_dim)
        offset_space = subword_len if use_offsets else 1
        for offsets in range(offset_space):
            if offsets > 0:
                xoff = mtf.shift(x, offsets, length_dim, wrap=False)
            else:
                xoff = x
            if use_block_pos_embedding:
                # Shift from the unshifted base each time so the embedding
                # shift matches xoff instead of accumulating across offsets.
                block_pos_emb = base_block_pos_emb
                if offsets > 0:
                    block_pos_emb = mtf.shift(block_pos_emb,
                                              offsets,
                                              block_pos_emb.shape[-2],
                                              wrap=False)
            tf.logging.info("SW len=%d offset=%d", subword_len, offsets)
            if length_dim.size % subword_len != 0:
                tf.logging.info("Not divisible by length")
                # add extra padding tokens
                pad_amt = int(subword_len) - int(length_dim.size % subword_len)
                kp = mtf.pad(xoff, [0, pad_amt], length_name)
            else:
                kp = xoff

            if use_block_pos_embedding:
                kp += block_pos_emb

            bx = mtf.pool_tensor_1d(
                kp,
                pool_dim=kp.shape.get_dim_by_name(length_name),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(subword_len))
            block_score = mtf.layers.dense(bx, [tmp_dim],
                                           use_bias=False,
                                           name="bx",
                                           reduced_dims=[model_dim],
                                           variable_dtype=None,
                                           kernel_weights=block_kernel)

            # Broadcast each block's representation/score back onto the
            # positions it covers.
            expand_bx = _repeat(bx, subword_len, length_dim)
            expand_scores = _repeat(block_score, subword_len, length_dim)
            if offsets > 0:
                # add offset.
                expand_bx = mtf.pad(expand_bx, [offsets, 0], length_name)
                expand_scores = mtf.pad(expand_scores, [offsets, 0],
                                        length_name)
            # Bring the length back to exactly length_dim.size.
            new_len = expand_bx.shape.get_dim_by_name(length_name)
            if new_len.size < length_dim.size:
                pad_amt = length_dim.size - new_len.size
                expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_name)
                expand_scores = mtf.pad(expand_scores, [0, pad_amt],
                                        length_name)
            elif new_len.size > length_dim.size:
                expand_bx = mtf.slice(expand_bx, 0, length_dim.size,
                                      length_name)
                expand_scores = mtf.slice(expand_scores, 0, length_dim.size,
                                          length_name)

            # Add a size-1 "extra_dim" so candidates can be concatenated.
            new_tmp_dim = mtf.Dimension("extra_dim", 1)
            expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim])
            expand_scores_shape = mtf.Shape(expand_scores.shape.dims +
                                            [new_tmp_dim])
            expand_bx = mtf.reshape(expand_bx, expand_shape)
            expand_scores = mtf.reshape(expand_scores, expand_scores_shape)
            all_blocks.append(expand_bx)
            all_scores.append(expand_scores)

    all_blocks = mtf.concat(all_blocks, new_tmp_dim.name)
    all_scores = mtf.concat(all_scores, new_tmp_dim.name)
    tf.logging.info(all_blocks)
    new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim")
    # Fold the per-block score dim into the candidate dim.
    combined_dim = _combined_dim([new_tmp_dim, tmp_dim])
    block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim
    block_net = mtf.reshape(all_scores, block_net_shape)

    if block_mixing_mode == "score_attention":
        tf.logging.info("Using score attention")
        att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim])
        tf.logging.info(block_net)
        att = mtf.softmax(att, reduced_dim=att.shape[-1])
        block_net = mtf.einsum([att, block_net], output_shape=block_net.shape)
        tf.logging.info(block_net)

    if activation == "softmax":
        block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim)
    elif activation == "tanh":
        tf.logging.info("Using tanh")
        block_net = mtf.tanh(block_net)

    # Score-weighted sum over all candidate blocks.
    all_blocks = block_net * all_blocks
    all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim)
    output = all_blocks

    if downsample:
        output_length = output.shape.get_dim_by_name(length_name)
        if output_length.size % int(downsample) != 0:
            pad_amt = int(downsample) - int(
                output_length.size % int(downsample))
            output = mtf.pad(output, [0, pad_amt], output_length.name)
        if downsample_function == "mean":
            output = mtf.pool_tensor_1d(
                output,
                pool_dim=output.shape.get_dim_by_name(length_name),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(downsample))
        else:
            raise ValueError("Downsampling function not implemented.")

    return output
    def get_timing_signal_1d(self,
                             context,
                             length,
                             channels,
                             min_timescale=1.0,
                             max_timescale=1.0e4,
                             start_index=0):
        """Gets a bunch of sinusoids of different frequencies.

        Each channel of the timing signal is a sinusoid with a different
        frequency and phase. This allows attention to learn to use absolute
        and relative positions. Relative positions are usable because
        sin(x+y) and cos(x+y) can be expressed in terms of y, sin(x) and
        cos(x).

        A geometric sequence of timescales is used, starting at min_timescale
        and ending at max_timescale. There are channels.size / 2 timescales.
        For each timescale, sin(timestep/timescale) and
        cos(timestep/timescale) are generated. All sinusoids are concatenated
        along the channels dimension.

        Args:
          context: mtf context; its get_position(), mesh and activation_dtype
            are used.
          length: a mtf.Dimension, length of timing signal sequence.
          channels: a mtf.Dimension, size of timing embeddings to create.
            channels.size must be even.
          min_timescale: a float.
          max_timescale: a float.
          start_index: index of first position.

        Returns:
          a Tensor of timing signals with dims [expanded(1), length, channels].

        Raises:
          NotImplementedError: if channels.size is odd.
        """
        # Validate before building any graph ops.
        if channels.size % 2 != 0:
            raise NotImplementedError("Odd channel size not implemented.")

        # NOTE(review): the final reshape assumes context.get_position() has
        # exactly `length` positions — TODO confirm against the caller.
        position = context.get_position() + start_index
        num_timescales = channels.size // 2
        # max(..., 1) guards against division by zero for a single timescale.
        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            max(num_timescales - 1, 1))
        half_channels_dim = mtf.Dimension(channels.name, num_timescales)
        inv_timescales = min_timescale * mtf.exp(
            mtf.mtf_range(context.mesh, half_channels_dim,
                          context.activation_dtype) *
            -log_timescale_increment)

        scaled_time = position * inv_timescales
        # Please note that this slightly differs from the published paper.
        # See a discussion here:
        # https://github.com/tensorflow/tensor2tensor/pull/177
        signal = mtf.concat(
            [mtf.sin(scaled_time), mtf.cos(scaled_time)],
            concat_dim_name=channels.name)

        # `length` and `channels` are mtf.Dimensions, so they are used
        # directly (Dimensions have no `.shape`).
        new_dims = [mtf.Dimension("expanded", 1), length, channels]
        return mtf.reshape(signal, mtf.Shape(new_dims))
示例#25
0
文件: attention.py 项目: jahau/mesh
    def __init__(self,
                 mesh,
                 query_input_dim,
                 memory_input_dim,
                 output_dim,
                 key_dim,
                 value_dim,
                 query_heads_dims,
                 memory_heads_dims,
                 variable_dtype,
                 shared_kv=False,
                 combine_dims=True,
                 ensemble_dim=None):
        """Creates the q/k/v/o (or q/kv/o) attention weight variables.

        combine_dims is a hack for faster execution. The heads and key/value
        dimensions are combined in the variables and the computation. The
        hack would not be necessary if XLA optimized einsum properly.

        Args:
          mesh: a Mesh
          query_input_dim: a Dimension
          memory_input_dim: a Dimension
          output_dim: a Dimension
          key_dim: a Dimension
          value_dim: a Dimension
          query_heads_dims: a list of Dimension (None is treated as [])
          memory_heads_dims: a list of Dimension (None is treated as [])
          variable_dtype: a mtf.VariableDType
          shared_kv: a boolean; if True a single "kv" variable replaces the
            separate "k" and "v" variables
          combine_dims: a boolean
          ensemble_dim: an optional Dimension, prepended to every variable
            shape when given

        Raises:
          ValueError: if shared_kv is set but key_dim != value_dim.
        """
        if shared_kv and key_dim != value_dim:
            raise ValueError("shared_kv requires key_dim == value_dim")

        # Store configuration first: the shape computations below read it
        # back through self.q_dims / self.k_dims / self.v_dims / self.o_dims.
        self.query_input_dim = query_input_dim
        self.memory_input_dim = memory_input_dim
        self.output_dim = output_dim
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.query_heads_dims = query_heads_dims or []
        self.memory_heads_dims = memory_heads_dims or []
        self.shared_kv = shared_kv
        self.combine_dims = combine_dims

        leading = [ensemble_dim] if ensemble_dim else []
        if not combine_dims:
            q_shape = leading + [query_input_dim] + self.q_dims
            k_shape = leading + [memory_input_dim] + self.k_dims
            v_shape = leading + [memory_input_dim] + self.v_dims
            o_shape = leading + self.o_dims + [output_dim]
        else:
            q_shape = leading + [query_input_dim, _combined_dim(self.q_dims)]
            k_shape = leading + [memory_input_dim, _combined_dim(self.k_dims)]
            v_shape = leading + [memory_input_dim, _combined_dim(self.v_dims)]
            o_shape = leading + [_combined_dim(self.o_dims), output_dim]

        # Fan-in based normal initializers.
        q_stddev = (query_input_dim.size * key_dim.size)**-0.5
        kv_stddev = memory_input_dim.size**-0.5
        o_stddev = mtf.Shape(self.query_heads_dims + [value_dim]).size**-0.5
        q_init = tf.random_normal_initializer(stddev=q_stddev)
        kv_init = tf.random_normal_initializer(stddev=kv_stddev)
        o_init = tf.random_normal_initializer(stddev=o_stddev)

        def _new_variable(var_name, var_shape, init):
            return mtf.get_variable(mesh,
                                    var_name,
                                    var_shape,
                                    initializer=init,
                                    dtype=variable_dtype)

        self.wq = _new_variable("q", q_shape, q_init)
        if not shared_kv:
            self.wk = _new_variable("k", k_shape, kv_init)
            self.wv = _new_variable("v", v_shape, kv_init)
        else:
            self.wkv = _new_variable("kv", k_shape, kv_init)
        self.wo = _new_variable("o", o_shape, o_init)
示例#26
0
def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
    """Builds the spatially partitioned UNet graph, its loss and eval metrics.

    Every conv/deconv helper receives image_nx_dim and image_ny_dim as the
    spatial block dimensions, and the final 1x1 convolution uses
    `conv2d_with_blocks` / `conv3d_with_blocks` with those same block dims.
    No train op is created here; the caller receives the loss and the
    batch-norm update ops.

    Args:
      mesh: a MeshTensorflow.mesh object.
      mesh_impl: a mesh implementation, such as SimdMeshImpl and
        PlacementMeshImpl.
      dataset_str: either 'train' or 'eval'. It sets the `is_training` flag
        forwarded to the conv helpers (used for batch_norm) and selects
        FLAGS.batch_size_train or FLAGS.batch_size_eval; any other value
        fails an assert.
      images: a laid out Tensor with shape [batch, x, y, num_channels]
        or [batch, x, y, z, num_channels].
      labels: a laid out Tensor with shape [batch, x, y, num_classes]
        or [batch, x, y, z, num_classes].

    Returns:
      A tuple (preds, loss, eval_metrics, all_bn_update_ops):
        preds: list of mtf Tensors -- the downsampled liver probability, the
          downsampled lesion probability, the downsampled lesion ground truth
          (only if FLAGS.output_ground_truth), then the per-case hard-dice
          `intersect` and `area_sum` tensors.
        loss: scalar mtf Tensor mixing the dice loss (weight
          FLAGS.dice_loss_weight) and the weighted cross-entropy loss
          (weight 1 - FLAGS.dice_loss_weight).
        eval_metrics: dict mapping metric names to scalar mtf Tensors.
        all_bn_update_ops: list of the batch-norm update ops returned by every
          conv_with_spatial_partition call.
    """

    # 'train' selects the training batch size and is_training=True; the only
    # other accepted value is 'eval'.
    is_training = (dataset_str == 'train')
    if dataset_str == 'train':
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_train)
    else:
        assert dataset_str == 'eval'
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_eval)
    # nx/ny: number of spatial blocks along x/y; sx/sy: size of each block
    # along x/y. The z axis (sz) is not divided into blocks.
    image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block)
    image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block)
    image_sx_dim = mtf.Dimension('image_sx_block',
                                 FLAGS.ct_resolution // FLAGS.image_nx_block)
    image_sy_dim = mtf.Dimension('image_sy_block',
                                 FLAGS.ct_resolution // FLAGS.image_ny_block)
    image_sz_dim = mtf.Dimension('image_sz_block', FLAGS.ct_resolution)
    image_c_dim = mtf.Dimension('image_c', FLAGS.image_c)
    label_c_dim = mtf.Dimension('label_c', FLAGS.label_c)
    mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str)

    mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype)
    # Master, slice and activation dtypes are all FLAGS.mtf_dtype.
    variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype)

    # Import input features.
    x = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(images),
                                   mtf_images_shape)
    x = mtf.cast(x, mtf_dtype)

    # Import ground truth labels.
    t = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(labels),
                                   mtf_labels_shape)
    t = mtf.cast(t, mtf_dtype)

    # Transpose the blocks: batch, then the block-grid dims (nx, ny), then the
    # per-block spatial dims, then channels last.
    if FLAGS.sampled_2d_slices:
        x = mtf.transpose(x, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_c_dim
        ])

        t = mtf.transpose(t, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            label_c_dim
        ])
    else:
        x = mtf.transpose(x, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_sz_dim, image_c_dim
        ])

        t = mtf.transpose(t, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_sz_dim, label_c_dim
        ])

    # Network.
    # levels[depth] holds the down-path output at each depth; it is reused by
    # the up path as a skip connection.
    levels = []
    all_bn_update_ops = []
    # add levels with convolution or down-sampling
    for depth in range(FLAGS.network_depth):
        for n_conv in range(FLAGS.n_conv_per_block):
            if depth == 0 and n_conv == 0:
                # no dropout in 1st layer.
                dropout_keep_p = 1.0
            else:
                dropout_keep_p = FLAGS.dropout_keep_p
            # Filter count doubles at each depth.
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth), dropout_keep_p,
                FLAGS.with_batch_norm, is_training,
                'conv_{}_{}'.format(depth, n_conv), variable_dtype,
                'conv_down_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)
        levels.append(x)

        # Max-pool by 2 between levels, but not after the deepest level.
        if depth < FLAGS.network_depth - 1:
            if FLAGS.sampled_2d_slices:
                x = mtf.layers.max_pool2d(x, ksize=(2, 2))
            else:
                x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2))

    # add levels with up-convolution or up-sampling
    for depth in range(FLAGS.network_depth - 1)[::-1]:
        x = deconv_with_spatial_partition(
            x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
            FLAGS.n_base_filters * (2**depth), FLAGS.dropout_keep_p,
            'conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1),
            variable_dtype, 'deconv_{}_0'.format(depth))
        # Skip connection: concatenate with the down-path output at this
        # depth along its last conv's channel dimension.
        x = mtf.concat([x, levels[depth]],
                       concat_dim_name='conv_{}_{}'.format(
                           depth, FLAGS.n_conv_per_block - 1))

        for n_conv in range(FLAGS.n_conv_per_block):
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth), FLAGS.dropout_keep_p,
                FLAGS.with_batch_norm, is_training,
                'conv_{}_{}'.format(depth, n_conv), variable_dtype,
                'conv_up_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)

    # no dropout in the final layer.
    # 1x1 convolution projecting to FLAGS.label_c class logits.
    if FLAGS.sampled_2d_slices:
        y = mtf.layers.conv2d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1),
            strides=(1, 1),
            padding='SAME',
            h_blocks_dim=image_nx_dim,
            w_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )
    else:
        y = mtf.layers.conv3d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1, 1),
            strides=(1, 1, 1),
            padding='SAME',
            d_blocks_dim=image_nx_dim,
            h_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )

    # use mtf.constant to make sure there is no CPU-side constants.
    def scalar(v, dtype):
        return mtf.constant(mesh, v, shape=[], dtype=dtype)

    # Hard class masks from the argmax over label_c: index 1 is treated as
    # liver and index 2 as lesion.
    argmax_t = mtf.argmax(t, label_c_dim)
    liver_t = mtf.cast(mtf.equal(argmax_t, scalar(1, tf.int32)), mtf_dtype)
    lesion_t = mtf.cast(mtf.equal(argmax_t, scalar(2, tf.int32)), mtf_dtype)

    argmax_y = mtf.argmax(y, label_c_dim)
    lesion_y = mtf.cast(mtf.equal(argmax_y, scalar(2, tf.int32)), mtf_dtype)

    # summary of class ratios.
    lesion_pred_ratio = mtf.reduce_mean(lesion_y)
    lesion_label_ratio = mtf.reduce_mean(lesion_t)

    # summary of accuracy.
    accuracy = mtf.reduce_mean(
        mtf.cast(mtf.equal(argmax_y, argmax_t), mtf_dtype))

    # Cross-entropy loss. Up-weight the liver region.
    pixel_loss = mtf.layers.softmax_cross_entropy_with_logits(
        y, t, label_c_dim)
    # Liver pixels get weight xen_liver_weight. NOTE(review): liver_t and
    # lesion_t come from the same argmax and are therefore mutually exclusive,
    # so a lesion pixel gets 1 + xen_lesion_weight - xen_liver_weight rather
    # than xen_lesion_weight. The formula looks written for a liver mask that
    # includes lesions -- confirm which weighting is intended.
    pixel_weight = scalar(1, mtf_dtype) + \
        liver_t * scalar(FLAGS.xen_liver_weight - 1, mtf_dtype) + \
        lesion_t * scalar(FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight,
                          mtf_dtype)
    loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight)

    # Dice loss
    # Soft (probability-based) lesion dice, negated so that minimizing the loss
    # maximizes dice. It is computed both per case (averaged over the batch)
    # and globally (summed over the batch), and the two are averaged.
    y_prob = mtf.softmax(y, reduced_dim=label_c_dim)
    lesion_prob = mtf.reduce_sum(mtf.slice(y_prob, 2, 1, 'label_c'),
                                 reduced_dim=mtf.Dimension('label_c', 1))
    prob_intersect = mtf.reduce_sum(lesion_prob * lesion_t,
                                    output_shape=mtf.Shape([batch_dim]))
    prob_area_sum = mtf.reduce_sum(lesion_prob + lesion_t,
                                   output_shape=mtf.Shape([batch_dim]))
    loss_dice_per_case = mtf.reduce_mean(
        scalar(-2, mtf_dtype) * prob_intersect /
        (prob_area_sum + scalar(FLAGS.dice_epsilon, mtf_dtype)))
    loss_dice_global = scalar(-2, mtf_dtype) * mtf.reduce_sum(
        prob_intersect) / (mtf.reduce_sum(prob_area_sum) +
                           scalar(FLAGS.dice_epsilon, mtf_dtype))

    loss_dice = (loss_dice_per_case + loss_dice_global) * scalar(
        0.5, mtf_dtype)

    # Convex combination of the dice loss and the weighted cross-entropy.
    loss = scalar(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + scalar(
        1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen

    # Hard (argmax-based) lesion dice terms, per case.
    intersect = mtf.reduce_sum(lesion_y * lesion_t,
                               output_shape=mtf.Shape([batch_dim]))
    area_sum = mtf.reduce_sum(lesion_y + lesion_t,
                              output_shape=mtf.Shape([batch_dim]))
    # summary of dice.
    dice_per_case = mtf.reduce_mean(
        scalar(2, mtf_dtype) * intersect /
        (area_sum + scalar(0.000001, mtf_dtype)))
    dice_global = scalar(2, mtf_dtype) * mtf.reduce_sum(intersect) / (
        mtf.reduce_sum(area_sum) + scalar(0.000001, mtf_dtype))

    eval_metrics = {
        'lesion_pred_ratio': lesion_pred_ratio,
        'lesion_label_ratio': lesion_label_ratio,
        'accuracy_of_all_classes': accuracy,
        'lesion_dice_per_case': dice_per_case,
        'lesion_dice_global': dice_global,
        'loss_xen': loss_xen,
        'loss_dice': loss_dice,
        'loss_dice_per_case': loss_dice_per_case,
        'loss_dice_global': loss_dice_global,
    }

    # Average-pool probabilities (and optionally the lesion ground truth) by
    # FLAGS.pred_downsample along every spatial axis for the returned preds.
    if FLAGS.sampled_2d_slices:
        y_prob_downsampled = mtf.layers.avg_pool2d(
            y_prob, ksize=(FLAGS.pred_downsample, ) * 2)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool2d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample, ) * 2)
    else:
        y_prob_downsampled = mtf.layers.avg_pool3d(
            y_prob, ksize=(FLAGS.pred_downsample, ) * 3)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool3d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample, ) * 3)

    # Slice out channels 1 (liver) and 2 (lesion); the reduce_sum over the
    # size-1 label_c dim just drops that dimension.
    liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c')
    lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c')
    preds = [
        mtf.reduce_sum(liver_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1)),
        mtf.reduce_sum(lesion_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1))
    ]

    if FLAGS.output_ground_truth:
        preds.append(
            mtf.reduce_sum(lesion_gt_downsampled,
                           reduced_dim=mtf.Dimension('label_c', 1)))

    preds.extend([intersect, area_sum])

    return preds, loss, eval_metrics, all_bn_update_ops
示例#27
0
def mnist_model(image, labels, mesh):
  """Builds a small blocked-convolution MNIST classifier on `mesh`.

  The flat 28*28 image is viewed as a 4x4 grid of 7x7 blocks, so the two
  9x9 convolutions can run through `conv2d_with_blocks`. Three dense layers
  then follow.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None when
      no loss should be built.
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [], or None when `labels` is None.
  """
  batch = mtf.Dimension("batch", FLAGS.batch_size)
  h_blocks = mtf.Dimension("row_blocks", 4)
  w_blocks = mtf.Dimension("col_blocks", 4)
  h_size = mtf.Dimension("rows_size", 7)
  w_size = mtf.Dimension("cols_size", 7)
  channel = mtf.Dimension("one_channel", 1)
  classes = mtf.Dimension("classes", 10)

  # Row-major split of each 28x28 image: (row_block, row, col_block, col),
  # then move both block dims ahead of the in-block dims.
  blocked_image = tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1])
  net = mtf.import_tf_tensor(
      mesh, blocked_image,
      mtf.Shape([batch, h_blocks, h_size, w_blocks, w_size, channel]))
  net = mtf.transpose(net, [batch, h_blocks, w_blocks, h_size, w_size, channel])

  # Two ReLU convolutions, just to show that blocked convolution works.
  conv_specs = [
      ("conv0", mtf.Dimension("filters1", 16)),
      ("conv1", mtf.Dimension("filters2", 16)),
  ]
  for layer_name, filters in conv_specs:
    net = mtf.relu(mtf.layers.conv2d_with_blocks(
        net, filters, filter_size=[9, 9], strides=[1, 1], padding="SAME",
        h_blocks_dim=h_blocks, w_blocks_dim=w_blocks, name=layer_name))
  last_filters = conv_specs[-1][1]
  net = mtf.reduce_mean(net, reduced_dim=last_filters)

  # Fully-connected head; the first layer flattens the four spatial dims.
  hidden1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  hidden2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
  net = mtf.layers.dense(net, hidden1, reduced_dims=net.shape.dims[-4:],
                         activation=mtf.relu, name="hidden1")
  net = mtf.layers.dense(net, hidden2, activation=mtf.relu, name="hidden2")
  logits = mtf.layers.dense(net, classes, name="logits")

  if labels is None:
    return logits, None

  mtf_labels = mtf.import_tf_tensor(
      mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch]))
  per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, mtf.one_hot(mtf_labels, classes), classes)
  return logits, mtf.reduce_mean(per_example_loss)
示例#28
0
文件: ops_test.py 项目: Mrfixsit/mesh
class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):
    """Tests for Mesh TensorFlow conversion helpers, Graph, Mesh and MeshImpl."""

    @parameterized.parameters(
        (mtf.Dimension("x", 5), ),
        (("x", 5), ),
    )
    def testConvertToDimension(self, inputs):
        converted = mtf.convert_to_dimension(inputs)
        self.assertEqual((converted.name, converted.size), ("x", 5))

    def testConvertToDimensionGenericInputs(self):
        self.assertEqual(mtf.convert_to_dimension(None), None)
        with self.assertRaises(TypeError):
            mtf.convert_to_dimension(5)

    @parameterized.parameters(
        (mtf.Shape([mtf.Dimension("x", 4),
                    mtf.Dimension("y", 8)]), ),
        ("x:4;y:8", ),
        ("x:4.y:8", ),
        ("x:4 y:8", ),
        ("x:4,y:8", ),
    )
    def testConvertToShape(self, inputs):
        converted = mtf.convert_to_shape(inputs)
        expected = mtf.Shape([mtf.Dimension("x", 4), mtf.Dimension("y", 8)])
        self.assertEqual(converted, expected)

    def testConvertToShapeGenericInputs(self):
        self.assertEqual(mtf.convert_to_shape([]).dims, [])
        self.assertEqual(mtf.convert_to_shape(None), None)
        with self.assertRaises(ValueError):
            mtf.convert_to_shape("x;4")

    @parameterized.parameters(
        (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]), ),
        ("d_ff:model;heads:model", ),
        ("d_ff:model.heads:model", ),
        ("d_ff:model heads:model", ),
        ("d_ff:model,heads:model", ),
        ([("d_ff", "model"), ("heads", "model")], ),
    )
    def testConvertToLayoutRules(self, inputs):
        converted_pairs = mtf.convert_to_layout_rules(inputs)._pairs
        expected_pairs = mtf.LayoutRules(
            [("d_ff", "model"), ("heads", "model")])._pairs
        self.assertEqual(converted_pairs, expected_pairs)

    def testConvertToLayoutRulesGenericInputs(self):
        with self.assertRaises(ValueError):
            mtf.convert_to_layout_rules("d_ff;heads")

    def testTensorLayout(self):
        layout = mtf.TensorLayout([0, 2, 1])
        expected_by_mesh_axis = [(), (0, ), (0, 2)]
        for mesh_axis, expected in enumerate(expected_by_mesh_axis):
            self.assertEqual(layout.mesh_axis_to_tensor_axis(mesh_axis),
                             expected)
        self.assertFalse(mtf.TensorLayout([None, 0]).is_fully_replicated)
        self.assertTrue(
            mtf.TensorLayout([None, None, None]).is_fully_replicated)

    def testGraph(self):
        graph = mtf.Graph()

        def check_counts(num_ops, num_trainable, num_variables):
            # Compare the graph's bookkeeping lists against expected sizes.
            self.assertLen(graph.operations, num_ops)
            self.assertLen(graph.trainable_variables, num_trainable)
            self.assertLen(graph.all_variables, num_variables)

        check_counts(0, 0, 0)
        mesh = mtf.Mesh(graph, "mesh_test")
        mtf.import_tf_tensor(mesh,
                             tf_tensor=tf.constant(0.),
                             shape=mtf.Shape([]))
        check_counts(1, 0, 0)
        mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
        check_counts(2, 1, 1)
        mtf.get_variable(mesh, "variable_1", mtf.Shape([]), trainable=False)
        check_counts(3, 1, 2)

    def testGraphNames(self):
        # Each scenario starts from a fresh Graph and lists
        # (requested name, expected unique name) pairs in call order.
        scenarios = [
            # Standard Usage.
            [("a", "a"), ("a", "a_1"), ("a", "a_2")],
            # Edge cases, the user may choose the name "a_1".
            [("a", "a"), ("a", "a_1"), ("a_1", "a_1_1")],
            [("a", "a"), ("a_1", "a_1"), ("a", "a_2")],
            # Case insensitive.
            [("a", "a"), ("A", "A_1")],
        ]
        for requests in scenarios:
            graph = mtf.Graph()
            for requested, expected in requests:
                self.assertEqual(graph.unique_name(requested), expected)

    @tf.contrib.eager.run_test_in_graph_and_eager_modes()
    def testLowering(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        tf_inputs = tf.constant(0.)
        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          tf_tensor=tf_inputs,
                                          shape=mtf.Shape([]))
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_outputs = lowering.export_to_tf_tensor(mtf_inputs)
        in_value, out_value = self.evaluate([tf_inputs, tf_outputs])
        self.assertEqual(in_value, out_value)

        # These only need to run without raising.
        lowering.copy_masters_to_slices()
        lowering.copy_slices_to_masters()

    def testMesh(self):
        graph = mtf.Graph()
        self.assertEqual(mtf.Mesh(graph, "my_mesh").graph, graph)

    def testMeshImpl(self):
        mesh_shape = mtf.Shape(
            [mtf.Dimension("batch", 4),
             mtf.Dimension("model", 8)])
        rules = mtf.LayoutRules([("batch", "batch"), ("d_ff", "model"),
                                 ("heads", "model")])
        mesh_impl = mtf.MeshImpl(shape=mesh_shape, layout_rules=rules)
        self.assertEqual(mesh_impl.shape, mesh_shape)
        self.assertLen(mesh_shape, mesh_impl.ndims)
        self.assertEqual(mesh_impl.layout_rules, rules)
        self.assertEqual(mesh_impl.size, mesh_shape.size)
        self.assertTrue(mesh_impl.supports_control_dependencies)

        batch = mtf.Dimension("batch", 128)
        length = mtf.Dimension("length", 500)
        d_ff = mtf.Dimension("d_ff", 2048)
        heads = mtf.Dimension("heads", 8)
        for tensor_dim, mesh_axis in [(batch, 0), (d_ff, 1), (heads, 1)]:
            self.assertEqual(
                mesh_impl.tensor_dimension_to_mesh_axis(tensor_dim),
                mesh_axis)
        self.assertEqual(
            mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
            mtf.TensorLayout([0, None, 1]))
示例#29
0
    def mtf_model_fn(self, features, mesh):
        """Builds the Mesh TensorFlow image-transformer decoder graph.

        The targets are flattened into a 1-D sequence of length
        img_len * img_len * num_channels and shifted right. They are then
        embedded and summed with the 2-D positional embeddings, plus an input
        embedding for conditional models. The result runs through the
        configured local-attention decoder and is projected to vocabulary
        logits.

        Args:
          features: a dict of tf.Tensors. It must contain "targets", and also
            "inputs" when `self.has_input` and not `hparams.unconditional`.
          mesh: the mtf.Mesh to build the graph on.

        Returns:
          A tuple (logits, loss): logits reshaped to
          [batch, rows, orig_cols, channels, outputs_vocab], and the scalar
          mean softmax cross-entropy loss.

        Raises:
          ValueError: if hparams.attention_type is not one of
            "local1d_spatial", "local2d_spatial" or "local1d".
        """
        features = copy.copy(features)
        # Lazy %-formatting: the dict is only stringified if INFO is enabled.
        tf.logging.info("features = %s", features)
        hparams = self._hparams
        activation_dtype = self.activation_type

        # We assume fixed vocab size for targets
        targets = tf.to_int32(features["targets"])

        # Image preprocessing, reshape into a 1D sequence and shift right.
        length = hparams.img_len * hparams.img_len * hparams.num_channels
        targets = tf.reshape(targets, [hparams.batch_size, length])
        shifted_targets = common_layers.shift_right_2d(targets)

        # Declare all the dimensions
        batch_dim = mtf.Dimension("batch", hparams.batch_size)

        def import_to_batch_by_length(x, name):
            return mtf.import_tf_tensor(mesh,
                                        x,
                                        mtf.Shape([batch_dim,
                                                   self.length_dim]),
                                        name=name)

        targets = import_to_batch_by_length(targets, "targets")
        shifted_targets = import_to_batch_by_length(shifted_targets,
                                                    "shifted_targets")

        # Always empty here; kept as the hook for auxiliary losses added to
        # the main loss below.
        extra_losses = []

        # Create targets content and position embeddings.
        # Create embedding var for targets and positions and do a gather.
        targets_embedding_var = mtf.get_variable(
            mesh,
            "targets_embedding",
            mtf.Shape([self.targets_vocab_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)

        x = mtf.gather(targets_embedding_var, shifted_targets,
                       self.targets_vocab_dim)

        # Add positional embeddings
        x += mtf.reshape(self.create_positional_emb_2d(targets),
                         [self.length_dim, self.model_dim])

        # If conditional and input is given, add the input embedding to the target.
        # TODO(nikip): Verify conditional.
        if self.has_input and not hparams.unconditional:
            # NOTE(review): assumes features["inputs"] has size-1 axes 2 and 3
            # (squeezed away here) -- confirm against the input pipeline.
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = import_to_batch_by_length(inputs, "inputs")

            # Input embeddings. This creates the embedding table directly,
            # mirroring the targets embedding above. The previous
            # `mtf.layers.embedding(mesh, ...)` call passed the mesh where
            # that function expects the indices tensor, and its result was
            # then gathered a second time.
            inputs_embedding_var = mtf.get_variable(
                mesh,
                "input_embedding",
                mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
                initializer=tf.random_normal_initializer(),
                activation_dtype=activation_dtype)
            inputs_emb = mtf.gather(inputs_embedding_var, inputs,
                                    self.inputs_vocab_dim)
            x += inputs_emb

        # Image Transformer Decoder
        # [ self attention - ffn - residual + dropout] x n
        if hparams.attention_type == "local1d_spatial":
            decoder_output = local_attention1d_spatial_decoder(
                x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
        elif hparams.attention_type == "local2d_spatial":
            decoder_output = local_attention2d_spatial_decoder(
                x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
        elif hparams.attention_type == "local1d":
            decoder_output = local_attention1d_masked_decoder(
                x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
        else:
            raise ValueError("Invalid attention type.")

        # Calculate the logits and loss.
        logits = mtf.layers.dense(decoder_output,
                                  self.outputs_vocab_dim,
                                  name="logits")
        # Need a reshape for logits
        logits = mtf.reshape(
            logits,
            mtf.Shape([batch_dim, self.length_dim, self.outputs_vocab_dim]))
        soft_targets = mtf.one_hot(targets,
                                   self.outputs_vocab_dim,
                                   dtype=activation_dtype)
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, self.outputs_vocab_dim)
        loss = mtf.reduce_mean(loss)
        for l in extra_losses:
            loss += l

        # Reshape logits to original target shape.
        logits = mtf.reshape(
            logits,
            mtf.Shape([
                batch_dim, self.rows_dim, self.orig_cols_dim,
                self.channels_dim, self.outputs_vocab_dim
            ]))

        return logits, loss
示例#30
0
def _crop_synthetic_logits(r, q, memory_length_dim, context,
                           gather_incremental=True):
    """Crops max-length synthetic logits down to the current sequence lengths.

    The memory-length axis is always sliced to `memory_length_dim.size`. In
    "incremental" mode, the row at `context.position` is optionally gathered
    along "length". Otherwise "length" is sliced to the length of `q`.

    Args:
      r: a Tensor that has a "length" dimension and a dimension named
        `memory_length_dim.name`.
      q: the query Tensor; its "length" dimension gives the target length.
      memory_length_dim: a Dimension giving the target memory length.
      context: the transformer context (provides `mode` and `position`).
      gather_incremental: whether to gather `context.position` along
        "length" in incremental mode. The dense synthesizers pass False and
        leave "length" untouched.

    Returns:
      the cropped Tensor.
    """
    r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
    if context.mode == "incremental":
        if gather_incremental:
            r = mtf.gather(r, context.position,
                           r.shape.get_dim_by_name("length"))
        return r
    length_dim = q.shape.get_dim_by_name("length")
    return mtf.slice(r, 0, length_dim.size, length_dim.name)


def _mix_synthetic_logits(logits, r, synthesize_mode, context):
    """Combines dot-product `logits` with synthetic logits `r`.

    Modes whose name contains "alpha" learn a scalar gate. Synthetic logits
    are weighted by sigmoid(alpha) and dot-product logits by
    1 - sigmoid(alpha); alpha is zero-initialized, so the mix starts at
    0.5 / 0.5. Other modes simply add the two.
    """
    if "alpha" not in synthesize_mode:
        return logits + r
    alpha = mtf.get_variable(context.mesh,
                             "alpha",
                             mtf.Shape([mtf.Dimension("alpha", 1)]),
                             initializer=tf.zeros_initializer(),
                             dtype=context.variable_dtype)
    alpha = mtf.sigmoid(alpha)
    return ((1 - alpha) * logits) + (alpha * r)


def synthetic_attention(q,
                        k,
                        v,
                        memory_length_dim,
                        key_dim,
                        value_dim,
                        bias=None,
                        dropout_rate=0.0,
                        dropout_broadcast_dims=None,
                        extra_logit=None,
                        synthesize=True,
                        synthesize_mode="random_plus_alpha",
                        factorized_dim=16,
                        max_length=512,
                        context=None):
    """Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743).

    key_dim is a Dimension representing the channels in the queries and keys
    value_dim is a Dimension representing the channels in values
    memory_length_dim is a Dimension representing the different key/value pairs.

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims

    Typically, other_query_dims={batch, heads, length}
    Typically, other_memory_dims={batch, heads}

    Supported `synthesize_mode` values:
      "random": learned [length, heads, memory_length] logits (size
        max_length), independent of the inputs.
      "factorized": random logits factorized as R1 x R2 through a rank
        `factorized_dim` "tmp" dimension (sizes from max_length).
      "dense_minus": logits predicted from relu(q) by a dense layer.
      "random_plus" / "random_plus_alpha", "dense_plus" /
        "dense_plus_alpha": dot-product logits combined with random/dense
        synthetic logits, either added or mixed by a learned sigmoid gate.
        These keep their 512-sized tables independent of `max_length`, so
        existing checkpoints keep their variable shapes.
    With `synthesize=False`, plain dot-product attention logits are used.

    Args:
      q: a Tensor
      k: a Tensor
      v: a Tensor
      memory_length_dim: a Dimension
      key_dim: a Dimension
      value_dim: a Dimension
      bias: a Tensor to be added into the attention logits.
      dropout_rate: a float.
      dropout_broadcast_dims: an optional list of mtf.Dimension
      extra_logit: an optional scalar or tensor
      synthesize: flag to use synthetic attention or not
      synthesize_mode: which variant of synthesizer to use
      factorized_dim: factorized dim for synthesizers
      max_length: max length of input sequence
      context: context since we need context mode

    Returns:
      Tensor with shape q.shape - key_dim + value_dim

    Raises:
      ValueError: if `synthesize` is True and `synthesize_mode` is not one of
        the modes listed above.
    """

    if synthesize:
        num_heads = v.shape.get_dim_by_name("heads")
        tf.logging.info("Using synthesizer")
        if synthesize_mode == "random":
            tf.logging.info("Using Random Synthesizers")
            r_shape = mtf.Shape([
                mtf.Dimension("length", max_length),
                mtf.Dimension("heads", num_heads.size),
                mtf.Dimension("memory_length", max_length)
            ])
            r = mtf.get_variable(context.mesh,
                                 "R",
                                 r_shape,
                                 initializer=None,
                                 dtype=context.variable_dtype)
            logits = _crop_synthetic_logits(r, q, memory_length_dim, context)
        elif synthesize_mode == "factorized":
            tf.logging.info("Using Factorized Random Synthesizers")
            # R[length, memory_length] = sum_tmp R1[length, tmp] *
            # R2[tmp, memory_length], per head. (The old R1/R2 shapes had no
            # "length" dimension, so the einsum output could not be formed.
            # It also rebound the `k` argument.)
            tmp_dim = mtf.Dimension("tmp", factorized_dim)
            heads_dim = mtf.Dimension("heads", num_heads.size)
            length_dim = mtf.Dimension("length", max_length)
            memory_dim = mtf.Dimension("memory_length", max_length)
            r1 = mtf.get_variable(context.mesh,
                                  "R1",
                                  mtf.Shape([length_dim, heads_dim, tmp_dim]),
                                  initializer=None,
                                  dtype=context.variable_dtype)
            r2 = mtf.get_variable(context.mesh,
                                  "R2",
                                  mtf.Shape([tmp_dim, heads_dim, memory_dim]),
                                  initializer=None,
                                  dtype=context.variable_dtype)
            r = mtf.einsum([r1, r2],
                           mtf.Shape([length_dim, heads_dim, memory_dim]))
            logits = _crop_synthetic_logits(r, q, memory_length_dim, context)
        elif synthesize_mode == "dense_minus":
            # Dense Synthesizer Model
            tmp_dim = mtf.Dimension("memory_length", max_length)
            logits = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                      use_bias=False,
                                      name="pi",
                                      reduced_dims=[key_dim],
                                      variable_dtype=None)
            logits = _crop_synthetic_logits(logits, q, memory_length_dim,
                                            context, gather_incremental=False)
        elif synthesize_mode in ("random_plus_alpha", "random_plus"):
            # Mixture Random Synthesizer with learnable Alpha
            tf.logging.info("Using Random Plus Alpha")
            logits = mtf.einsum([q, k], reduced_dims=[key_dim])
            num_heads = logits.shape.get_dim_by_name("heads")
            r_shape = mtf.Shape([
                mtf.Dimension("length", 512),
                mtf.Dimension("heads", num_heads.size),
                mtf.Dimension("memory_length", 512)
            ])
            r = mtf.get_variable(context.mesh,
                                 "R",
                                 r_shape,
                                 initializer=None,
                                 dtype=context.variable_dtype)
            r = _crop_synthetic_logits(r, q, memory_length_dim, context)
            logits = _mix_synthetic_logits(logits, r, synthesize_mode, context)
        elif synthesize_mode in ("dense_plus_alpha", "dense_plus"):
            # Mixture Dense Synthesizer with learnable alpha
            tf.logging.info("Using Dense Plus Alpha Scaling")
            logits = mtf.einsum([q, k], reduced_dims=[key_dim])
            tmp_dim = mtf.Dimension("memory_length", 512)
            r = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                 use_bias=False,
                                 name="pi",
                                 reduced_dims=[key_dim],
                                 variable_dtype=None)
            r = _crop_synthetic_logits(r, q, memory_length_dim, context,
                                       gather_incremental=False)
            logits = _mix_synthetic_logits(logits, r, synthesize_mode, context)
        else:
            # Previously fell through and failed later on an unbound `logits`.
            raise ValueError("Unknown synthesize_mode: %s" % synthesize_mode)
    else:
        # Plain dot-product attention. Previously `logits` was never assigned
        # on this path.
        logits = mtf.einsum([q, k], reduced_dims=[key_dim])

    if bias is not None:
        logits += bias

    weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
    if dropout_rate != 0.0:
        weights = mtf.dropout(weights,
                              1.0 - dropout_rate,
                              noise_shape=weights.shape -
                              dropout_broadcast_dims)

    # Pure synthesizers ("random", "factorized", "dense_minus") get an explicit
    # output shape. The "plus" modes and plain attention drop key_dim and add
    # value_dim.
    if synthesize and "plus" not in synthesize_mode:
        if synthesize_mode == "dense_minus":
            outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim])
        else:
            outputs_shape = mtf.Shape(q.shape.dims[:-1] +
                                      [num_heads, value_dim])
    else:
        outputs_shape = q.shape - [key_dim] + value_dim

    outputs = mtf.einsum([weights, v], outputs_shape)
    return outputs