Example #1
    def create_positional_emb_2d(self, targets):
        """Learned 2d positional embedding for images."""
        mesh = targets.mesh

        positional_emb_rows_var = mtf.get_variable(
            mesh,
            "positional_emb_rows",
            mtf.Shape([self.pos_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=self.activation_type)
        positional_emb_cols_var = mtf.get_variable(
            mesh,
            "positional_emb_cols",
            mtf.Shape([self.pos_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=self.activation_type)

        targets_position_x = mtf.range(mesh, self.rows_dim, dtype=tf.int32)
        targets_position_y = mtf.range(mesh, self.cols_dim, dtype=tf.int32)
        position_x = mtf.broadcast(
            mtf.gather(positional_emb_rows_var, targets_position_x,
                       self.pos_dim),
            mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))

        position_y = mtf.broadcast(
            mtf.gather(positional_emb_cols_var, targets_position_y,
                       self.pos_dim),
            mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))
        return position_x + position_y
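The same row-plus-column decomposition, sketched standalone for reference. This is a hedged restatement outside the class: the dimension names and sizes below are assumptions for illustration, not values taken from the model above.

import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")

rows_dim = mtf.Dimension("rows", 32)       # image height (assumed)
cols_dim = mtf.Dimension("cols", 32)       # image width (assumed)
pos_dim = mtf.Dimension("pos", 32)         # lookup-table length (assumed)
model_dim = mtf.Dimension("d_model", 64)   # embedding width (assumed)

row_table = mtf.get_variable(
    mesh, "pos_emb_rows", mtf.Shape([pos_dim, model_dim]),
    initializer=tf.random_normal_initializer())
col_table = mtf.get_variable(
    mesh, "pos_emb_cols", mtf.Shape([pos_dim, model_dim]),
    initializer=tf.random_normal_initializer())

# One embedding per row index and one per column index...
row_emb = mtf.gather(row_table, mtf.range(mesh, rows_dim, dtype=tf.int32), pos_dim)
col_emb = mtf.gather(col_table, mtf.range(mesh, cols_dim, dtype=tf.int32), pos_dim)

# ...broadcast to the full grid and summed: every (row, col) cell gets a distinct
# vector, but only rows + cols embeddings are learned rather than rows * cols.
full_shape = mtf.Shape([rows_dim, cols_dim, model_dim])
pos_emb_2d = mtf.broadcast(row_emb, full_shape) + mtf.broadcast(col_emb, full_shape)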
def norm(x, axis=None, epsilon=1e-5):
    axis = default(axis, x.shape[-1])

    u = mtf.reduce_mean(x, reduced_dim=axis)
    s = mtf.reduce_mean(mtf.square(x - u), reduced_dim=axis)

    u = mtf.broadcast(u, x.shape)
    s = mtf.broadcast(s, x.shape)

    return (x - u) * mtf.rsqrt(s + epsilon)
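norm above is a plain mean/variance normalization with no learned gain or bias (unlike a full LayerNorm). A minimal usage sketch, assuming the file's default helper simply returns its first argument unless that is None; names and sizes here are illustrative:

import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

def default(value, fallback):
    # assumed helper from the same file: use value unless it is None
    return value if value is not None else fallback

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")
batch = mtf.Dimension("batch", 2)
d_model = mtf.Dimension("d_model", 8)

x = mtf.random_uniform(mesh, mtf.Shape([batch, d_model]))
# Zero mean / unit variance over d_model for each batch element.
y = norm(x, axis=d_model)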
Example #4
File: gpt2.py  Project: doinker/GPTNeo
def axial_positional_emb(embd_dim, mesh, params, variable_dtype):
    # Use axial position encoding
    axial_dim_1, axial_dim_2 = params["axial_pos_emb"]

    axial_dim = mtf.Dimension("axial_dim", axial_dim_1 * axial_dim_2)
    dim_axials = [mtf.Dimension(f"axial_dim_{i}", t) for i, t in enumerate((axial_dim_1, axial_dim_2))]

    axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]),
                                   initializer=tf.random_normal_initializer(stddev=0.01),
                                   master_dtype=variable_dtype.master_dtype,
                                   slice_dtype=variable_dtype.slice_dtype,
                                   activation_dtype=variable_dtype.activation_dtype)

    axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]),
                                   initializer=tf.random_normal_initializer(stddev=0.01),
                                   master_dtype=variable_dtype.master_dtype,
                                   slice_dtype=variable_dtype.slice_dtype,
                                   activation_dtype=variable_dtype.activation_dtype)

    axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]),
                                   (axial_wpe_1, axial_wpe_2))
    wpe = (axial_wpe_1 + axial_wpe_2) / 2

    wpe = mtf.reshape(wpe, [axial_dim, embd_dim])

    return wpe
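A hedged usage sketch for axial_positional_emb with made-up hyperparameters, mainly to show the parameter saving of the axial factorization (the concrete sizes and the VariableDType values are assumptions):

import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")

embd_dim = mtf.Dimension("embd", 768)
params = {"axial_pos_emb": (32, 64)}   # factorizes 32 * 64 = 2048 positions
variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)
# wpe :: [axial_dim (2048), embd (768)], built from only 32 + 64 = 96 learned
# position vectors instead of a full 2048-entry table.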
Example #5
File: gpt2.py  Project: doinker/GPTNeo
def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh):
    """memory / key values from all attention paper"""

    dim_mem_kv = mtf.Dimension("mem_kv_sequence", num_mem_kv)
    emb_dim = k.shape[-1]
    mem_std = 1 / math.sqrt(emb_dim.size)

    mem_k = mtf.get_variable(mesh, "mem_k", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
                             initializer=tf.random_normal_initializer(stddev=mem_std),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype,
                             )
    mem_v = mtf.get_variable(mesh, "mem_v", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
                             initializer=tf.random_normal_initializer(stddev=mem_std),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype)

    mem_k, mem_v = map(lambda t: mtf.broadcast(t, [dim_batch, dim_mem_kv, dim_heads, emb_dim]),
                       (mem_k, mem_v))
    mem_k, mem_v = map(lambda t: mtf.rename_dimension(t, "mem_kv_sequence", "sequence"),
                       (mem_k, mem_v))

    k = mtf.concat([mem_k, k], "sequence")
    v = mtf.concat([mem_v, v], "sequence")
    return k, v
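A hedged usage sketch for memory_key_values; the dimension names and sizes are assumptions, except that k and v must carry a dimension literally named "sequence", since that is the name the function renames "mem_kv_sequence" to and concatenates along:

import math
import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")

batch = mtf.Dimension("batch", 2)
seq = mtf.Dimension("sequence", 16)
heads = mtf.Dimension("heads", 4)
kv = mtf.Dimension("features_per_head", 8)

k = mtf.random_uniform(mesh, mtf.Shape([batch, seq, heads, kv]))
v = mtf.random_uniform(mesh, mtf.Shape([batch, seq, heads, kv]))
variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

# Prepend 16 learned key/value slots; the sequence dimension grows from 16 to 32.
k, v = memory_key_values(k, v, 16, batch, heads, variable_dtype, mesh)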
    def add_step_timing_signal_func(self, context, x, step):
        """Add n-dimensional embedding as the step (vertical) timing signal.

        Args:
          context: mtf context
          x: a tensor with shape [batch, length, depth]
          step: step

        Returns:
          a Tensor with the same shape as x.
        """
        if self.recurrence_type == "act":
            num_steps = self.act_max_steps
        else:
            num_steps = self.num_rec_steps
        channels = x.shape.dims[-1]

        if self.step_timing_signal_type == "learned":
            signal = self.get_layer_timing_signal_learned_1d(
                context, channels, step, num_steps)
        elif self.step_timing_signal_type == "sinusoid":
            signal = self.get_layer_timing_signal_sinusoid_1d(
                context, channels, step, num_steps)
        if self.add_or_concat_timing_signal == "add":
            x_with_timing = x + mtf.cast(signal, x.dtype)
        elif self.add_or_concat_timing_signal == "concat":
            batch_dim = x.shape.dims[0]
            out_shape = mtf.Shape([batch_dim] + x.shape.dims[1:])
            signal_tiled = mtf.broadcast(signal, out_shape)
            x_with_timing = mtf.concat(
                (x, signal_tiled),
                concat_dim_name=signal_tiled.dimension_names[-1])

        return x_with_timing
Example #7
 def get_attn_mask(self, mesh, nd, ns):
     if not exists(self.attn_mask):
         i = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
         j = mtf.range(mesh, ns, tf.int32)
         i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))
         self.attn_mask = mtf.cast(mtf.less(
             i, j), self.variable_dtype.activation_dtype) * -1e10
     return self.attn_mask
Example #8
File: utils.py  Project: zxhjiutian/gpt-neo
def biasmask_attn_weights(mesh, nd, ns, variable_dtype):
    # The old mask_attn_weights applied directly to the QK;
    # this returns a bias that the attention code from mtf adds to the attention matrix.
    # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
    # n_src and n_dest are both the same, i.e. equal to the sequence length
    # We rename ns because we want bias to have shape [batch, heads, memory_length, sequence] to match up with QK^T
    # Information flows from k and v (memory_length) to q (sequence)
    i = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
    j = mtf.range(mesh, ns, tf.int32)
    i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))
    dtype = variable_dtype.activation_dtype
    return mtf.cast(mtf.less(i, j), dtype) * -1e10
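The bias pattern built in Examples #7 and #8, written out in plain NumPy for a 4x4 case to make the masking direction explicit (purely illustrative):

import numpy as np

nd = ns = 4                              # dst (query) and src (key/memory) lengths
i = np.arange(nd)[:, None] + ns - nd     # query positions (offset when ns > nd)
j = np.arange(ns)[None, :]               # key positions
bias = np.where(i < j, -1e10, 0.0)
# bias (rows = queries, cols = keys):
#   0     -1e10  -1e10  -1e10
#   0      0     -1e10  -1e10
#   0      0      0     -1e10
#   0      0      0      0
# Added to QK^T before the softmax, so each query only attends to keys at or
# before its own position.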
    def add_position_timing_signal_func(self, context, x, step):
        """Add n-dimensional embedding as the position (horizontal) timing signal.

        Args:
          context: mtf context
          x: a tensor with shape [batch, length, depth]
          step: step

        Returns:
          a Tensor with the same shape as x.
        """

        if not self.position_start_index:
            index = 0

        elif self.position_start_index == "random":
            # Shift all positions randomly
            # TODO(dehghani): What would be reasonable for max number of shift?
            index = mtf.random_uniform(context.mesh, [],
                                       maxval=x.shape.dims[1].size,
                                       dtype=tf.int32)

        elif self.position_start_index == "step":
            # Shift positions based on the step
            if self.recurrence_type == "act":
                num_steps = self.act_max_steps
            else:
                num_steps = self.num_rec_steps
            index = mtf.cast(x.shape.dims[1].size * step / num_steps,
                             dtype=tf.int32)

        length = context.length_dim
        channels = context.model.model_dim
        signal = self.get_timing_signal_1d(context,
                                           length,
                                           channels,
                                           start_index=index)

        if self.add_or_concat_timing_signal == "add":
            x_with_timing = x + mtf.cast(signal, x.dtype)
        # Unimplemented
        if self.add_or_concat_timing_signal == "concat":
            batch_dim = x.shape.dims[0]
            out_shape = mtf.Shape([batch_dim] + signal.shape.dims[1:])
            signal_tiled = mtf.broadcast(signal, out_shape)
            x_with_timing = mtf.concat(
                (x, signal_tiled),
                concat_dim_name=signal_tiled.dimension_names[-1])

        return x_with_timing
Example #10
 def _noisy_targets_from_spec(self, targets, noising_spec, losses=None):
     if noising_spec["type"] == "mask":
         # Replace a randomly-chosen noising_spec["prob"] of input tokens with 0.
         return targets * mtf.cast(
             mtf.greater(mtf.random_uniform(targets.mesh, targets.shape),
                         noising_spec["prob"]), targets.dtype)
     elif noising_spec["type"] == "random_zipfian":
         # Replace a randomly-chosen noising_spec["prob"] of input tokens.
         # Rather than drawing the replacement tokens uniformly, we sample from
         #   a distribution favoring lower token-ids, assuming that the ids have
         #   been assigned in frequency order.  The probability of choosing an
         #   id is proportional to 1/(id+10)
         logits = mtf.log(1.0 / (mtf.range(
             targets.mesh, self.targets_vocab_dim, dtype=tf.float32) +
                                 10.0))
         logits = mtf.broadcast(logits,
                                new_shape=targets.shape + logits.shape)
         r = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
         use_noise = mtf.less(
             mtf.random_uniform(targets.mesh, targets.shape),
             noising_spec["prob"])
         return mtf.where(use_noise, r, targets)
     elif noising_spec["type"] == "transformer":
         # Train a small transformer to fill in masked out values, then
         # sample from it.
         hparams = self._hparams
         if hparams.mode != tf.estimator.ModeKeys.TRAIN:
             raise NotImplementedError("Not implemented")
         noiser_hparams = copy.copy(self._hparams)
         noiser_hparams.del_hparam("mode")
         noiser_hparams.override_from_dict(noising_spec["overrides"])
         with tf.variable_scope("noiser"):
             noiser = MtfTransformer(noiser_hparams,
                                     mode=hparams.mode,
                                     problem_hparams=self._problem_hparams)
             logits, loss = noiser._mtf_model_fn(  # pylint: disable=protected-access
                 self._original_features, targets.mesh)
             samples = mtf.sample_with_temperature(logits,
                                                   self.targets_vocab_dim)
         losses.append(loss)
         return samples
     else:
         raise ValueError("unknown noising spec %s" % noising_spec)
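To make the random_zipfian branch concrete, a small NumPy sketch of the replacement-token distribution it samples from (the vocabulary size here is made up):

import numpy as np

vocab_size = 10
ids = np.arange(vocab_size)
weights = 1.0 / (ids + 10.0)       # the mtf code takes the log of these as logits
probs = weights / weights.sum()
# probs[0] is about 0.139 and probs[9] about 0.073: low (frequent) token ids are
# roughly twice as likely to be drawn as ids at the end of this small vocabulary,
# matching the "favoring lower token-ids" comment above.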
def attention(x, dim_head, dim_features_head, scope='attn', causal=False):
    with tf.variable_scope(scope):
        mesh, batch, seq, dim = x.mesh, *x.shape

        dim_heads = mtf.Dimension('dim_heads',
                                  dim_head.size * dim_features_head.size)
        dim_intermediate = mtf.Dimension('qkv_dimension', dim_heads.size * 3)
        qkv = linear(x, dim_intermediate, bias=False, scope='to_qkv')

        q, k, v = mtf.split(qkv, dim_intermediate, 3)
        q, k, v = map(
            lambda t: mtf.reshape(t, [batch, seq, dim_head, dim_features_head]
                                  ), (q, k, v))
        q, k, v = map(
            lambda t: mtf.transpose(
                t, [batch, dim_head, seq, dim_features_head]), (q, k, v))

        k, v = map(
            lambda t: mtf.rename_dimension(t, seq.name, 'memory_length'),
            (k, v))
        mem_len_dim = v.shape[-2]

        dots = mtf.layers.us_einsum([q, k],
                                    [batch, dim_head, seq, mem_len_dim])

        if causal:
            i = mtf.range(mesh, seq, tf.int32)
            j = mtf.range(mesh, mem_len_dim, tf.int32)
            i, j = map(lambda t: mtf.broadcast(t, [seq, mem_len_dim]), (i, j))
            mask = mtf.less(i + mem_len_dim.size - seq.size, j)
            mask = mtf.cast(mask, tf.float32) * -1e10
            dots += mask

        attn = mtf.softmax(dots, mem_len_dim)
        out = mtf.einsum([attn, v], [batch, dim_head, seq, dim_features_head])

        out = mtf.transpose(out, [batch, seq, dim_head, dim_features_head])
        out = mtf.reshape(out, [batch, seq, dim_heads])

        combined_out = linear(out, dim, scope='combine_output')
        return combined_out
Example #13
File: gpt2.py  Project: doinker/GPTNeo
def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None):
    # x :: [batch, seq, n_embd]
    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

    # n_state is the same as config["n_embd"], which is also the same as dim_embd.
    assert n_state.size % params["n_head"] == 0

    dim_heads = mtf.Dimension("heads", params["n_head"])

    num_mem_kv = params.get("num_mem_kv", 0)
    use_num_mem_kv = num_mem_kv > 0

    with tf.variable_scope(scope):
        # Compute attention inputs
        dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"])
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=dim_embd,
            kv_dim=dim_kv,
            heads_dim=dim_heads,
            variable_dtype=variable_dtype
        )
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if is_incremental_inference(context):
            one_hot = mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(context):
            context.record_new_states([k, v])

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
                radius = params.get("local_attention_radius", 256)

                if is_incremental_inference(context):
                    q *= one_hot

                a = mtf_transformer.attention.local_attention_1d(
                    q, k, v,
                    length_dim=k.shape[1],
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=params["causal"],
                    attention_kwargs={},
                )

                if is_incremental_inference(context):
                    a = mtf.gather(a, context.position - 1, dim_seq)

            elif attention_type == "global":

                # TODO: pass in fake context
                # Broadcast mask bias across batch and heads
                if exists(bias):
                    if not is_incremental_inference(context):
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])
                    else:
                        # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
                        bias = mtf.gather(bias, context.position - 1, dim_seq)
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]])

                # memory key / values, from all-attention paper
                if use_num_mem_kv:
                    k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh)

                k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
                v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)

                attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0

                a = mtf_transformer.attention.attention(
                    q, k, v,
                    memory_length_dim=memory_length_dim,
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    bias=broadcasted_bias,
                    dropout_rate=attn_dropout_rate
                )

            elif attention_type == "linear":
                linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention
                a = linear_attn_fn(q, k, v)

            else:
                raise NotImplementedError("Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            b = mtf.get_variable(x.mesh, "o_b", [dim_embd], initializer=tf.constant_initializer(0),
                                 master_dtype=variable_dtype.master_dtype,
                                 slice_dtype=variable_dtype.slice_dtype,
                                 activation_dtype=variable_dtype.activation_dtype)
            a += b

        if params["mode"] == "train" and params["res_dropout"] > 0:
            a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
        return a
    def _call_internal(self,
                       context,
                       inputs,
                       targets=None,
                       attributes=None,
                       z=None):
        """Compute logits based on inputs (all positions in parallel).
        Also updates context if applicable.
        Args:
          context: a Context
          inputs: a Tensor
          targets: an optional Tensor
          attributes: an optional Tensor
        Returns:
          logits: a Tensor with shape [<batch_dims>, length_dim, output_vocab_dim]
        """
        mesh = inputs.mesh
        if self.ensemble_dim and self.ensemble_dim not in inputs.shape.dims:
            # Training an ensemble where all models are trained on the same examples.
            inputs = mtf.broadcast(inputs,
                                   [self.ensemble_dim] + inputs.shape.dims)
            if self.ensemble_dim not in attributes.shape.dims:
                attributes = mtf.broadcast(attributes, [self.ensemble_dim] +
                                           attributes.shape.dims)
            if targets:
                targets = mtf.broadcast(targets, [self.ensemble_dim] +
                                        targets.shape.dims)
        if "embedding" in context.shared_params:
            vocab_embedding = context.shared_params["embedding"]
        else:
            vocab_embedding = VocabEmbedding(mesh,
                                             self.input_vocab_dim,
                                             self.model_dim,
                                             context.variable_dtype,
                                             name="embedding",
                                             ensemble_dim=self.ensemble_dim)
        x = vocab_embedding.ids_to_embedding(inputs)
        if self.positional_embedding:
            if "positional_embedding" in context.shared_params:
                pos_emb_var = context.shared_params["positional_embedding"]
            else:
                pos_emb_var = mtf.layers.embedding_weights(
                    mesh,
                    self.max_length_dim,
                    self.model_dim,
                    context.variable_dtype,
                    "positional_embedding",
                    ensemble_dim=self.ensemble_dim)
            if (context.length_dim is not None
                    and context.length_dim.size > self.max_length_dim.size):
                message = (
                    "Length dimenison exceeds size of positional embedding table. "
                    "length_dim.size > max_length_dim.size %s vs %s." %
                    (context.length_dim, self.max_length_dim))
                if context.position_is_default:
                    # Definitely getting overflow in this case.
                    raise ValueError(message)
                else:
                    tf.logging.warning(
                        message +
                        " This may be OK if there are several shorter sequences packed "
                        "together.  Otherwise, the later positions will get zeros."
                    )
            if context.position_is_default:
                pos_emb = mtf.rename_dimension(
                    mtf.slice(pos_emb_var, 0, context.length_dim.size,
                              self.max_length_dim.name),
                    self.max_length_dim.name, context.length_dim.name)
            else:
                pos_emb = mtf.gather(pos_emb_var,
                                     context.position,
                                     self.max_length_dim,
                                     output_shape=x.shape)
            x += pos_emb

        if self.attribute_embedding:
            if "attribute_embedding" in context.shared_params:
                att_emb_var = context.shared_params["attribute_embedding"]
            else:
                att_emb_var = mtf.layers.embedding_weights(
                    mesh,
                    self.attribute_dim,
                    self.model_dim,
                    context.variable_dtype,
                    "attribute_embedding",
                    ensemble_dim=self.ensemble_dim)

            att_emb = mtf.gather(att_emb_var,
                                 attributes,
                                 self.attribute_dim,
                                 output_shape=x.shape)
            # Addition of x and attribute
            # x *= LAMBDA_ATTRIBUTE * sty_emb #

            # Concatenation of x and attribute
            x_attribute = mtf.concat([x, att_emb], self.model_dim.name)
            x = mtf.layers.dense(x_attribute,
                                 self.model_dim,
                                 activation=None,
                                 variable_dtype=context.variable_dtype,
                                 name="comb_x_attribute")

        if z:
            z = mtf.layers.dense(z,
                                 self.model_dim,
                                 activation=None,
                                 variable_dtype=context.variable_dtype,
                                 name="z")
            # raise ValueError("x shape=%s , z shape=%s" % (x.shape, z.shape))
            x += z

        x = self.layer_stack.call(context, x)
        if self.output_vocab_dim is None:
            return x
        if self.shared_embedding_and_softmax_weights:
            logits = vocab_embedding.hidden_to_logits(x)
        else:
            logits = mtf.layers.dense(x,
                                      self.output_vocab_dim,
                                      use_bias=False,
                                      variable_dtype=context.variable_dtype,
                                      reduced_dims=x.shape.dims[-1:],
                                      name="logits")
        if targets is not None and context.losses is not None:
            context.losses.append(
                self._compute_loss(context, logits, targets,
                                   self.output_vocab_dim))
        if self.ensemble_dim:
            logits = reduce_ensemble_logits(logits, self.ensemble_dim,
                                            self.output_vocab_dim)
        return logits
Example #15
    def attention(self,
                  x,
                  n_state,
                  mask,
                  attention_type="global",
                  name="attn"):
        # x :: [batch, seq, n_embd]
        batch_dim, seq_dim, embd_dim = x_shape = x.shape
        assert n_state.size % self.n_heads == 0, "n_state must be divisible by n_heads"
        with tf.variable_scope(name):
            # Compute attention inputs
            mtfparams = mtf.transformer.attention.attention_params_simple(
                x.mesh,
                io_dim=self.dimensions["embed_dim"],
                kv_dim=self.dimensions["kv_dim"],
                heads_dim=self.dimensions["heads_dim"],
                variable_dtype=self.variable_dtype)
            q = mtfparams.compute_q(x)
            k = mtfparams.compute_k(x)
            v = mtfparams.compute_v(x)

            if self.is_incremental_inference:
                one_hot = mtf.one_hot(self.context.position - 1,
                                      seq_dim,
                                      dtype=self.variable_dtype.master_dtype)
                inv_one_hot = 1.0 - one_hot
                old_k, old_v = self.context.get_states(2)
                k = old_k * inv_one_hot + k * one_hot
                v = old_v * inv_one_hot + v * one_hot

            if exists(self.context):
                self.context.record_new_states([k, v])

            with tf.variable_scope("attention"):
                if attention_type == "local":
                    # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
                    radius = self.params.get("local_attention_radius", 256)
                    if self.is_incremental_inference:
                        q *= one_hot
                    a = mtf_transformer.attention.local_attention_1d(
                        q,
                        k,
                        v,
                        length_dim=k.shape[1],
                        key_dim=self.dimensions["kv_dim"],
                        value_dim=self.dimensions["kv_dim"],
                        radius=radius,
                        length_dim_num_splits=1,
                        fully_autoregressive=True,
                        attention_kwargs={},
                    )
                    if self.is_incremental_inference:
                        a = mtf.gather(a, self.context.position - 1, seq_dim)

                elif attention_type == "global":
                    if exists(mask):
                        if not self.is_incremental_inference:
                            broadcasted_mask = mtf.broadcast(
                                mask, [
                                    batch_dim, self.dimensions["heads_dim"],
                                    mask.shape[-2], mask.shape[-1]
                                ])  # TODO: not sure this is correct
                        else:
                            # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
                            mask = mtf.gather(mask, self.context.position - 1,
                                              seq_dim)
                            broadcasted_mask = mtf.broadcast(
                                mask, [
                                    batch_dim, self.dimensions["heads_dim"],
                                    mask.shape[-1]
                                ])

                    k = mtf.replace_dimensions(
                        k, k.shape[1], self.dimensions["memory_len_dim"])
                    v = mtf.replace_dimensions(
                        v, v.shape[1], self.dimensions["memory_len_dim"])

                    attn_dropout_rate = self.params.get(
                        "attention_dropout", 0) if self.mode == "train" else 0

                    a = mtf_transformer.attention.attention(
                        q,
                        k,
                        v,
                        memory_length_dim=self.dimensions["memory_len_dim"],
                        key_dim=self.dimensions["kv_dim"],
                        value_dim=self.dimensions["kv_dim"],
                        bias=broadcasted_mask,
                        dropout_rate=attn_dropout_rate)
                else:
                    raise NotImplementedError(
                        "Unknown attention type {}!".format(attention_type))

            with tf.variable_scope("compute_output"):
                a = mtfparams.compute_output(a, x_shape)

            with tf.variable_scope("compute_output_bias"):
                b = mtf.get_variable(
                    x.mesh,
                    "o_b", [embd_dim],
                    initializer=tf.constant_initializer(0),
                    master_dtype=self.variable_dtype.master_dtype,
                    slice_dtype=self.variable_dtype.slice_dtype,
                    activation_dtype=self.variable_dtype.activation_dtype)
                a += b
            residual_dropout = self.params.get("residual_dropout", 0)
            if self.mode == "train" and residual_dropout > 0:
                a = mtf.dropout(a, rate=residual_dropout, name="res_dropout")
            return a
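The incremental-inference branches in Examples #13 and #15 update the cached keys/values with a one-hot blend rather than an in-place write. A small NumPy sketch of that update step (shapes and the position value are made up):

import numpy as np

seq_len, d = 8, 4
position = 3                                  # current decode step (assumed)

old_k = np.zeros((seq_len, d))                # cached keys from previous steps
new_k = np.ones((seq_len, d))                 # freshly computed keys; only the
                                              # row at position matters here
one_hot = np.eye(seq_len)[position][:, None]  # 1.0 only at the current position
k = old_k * (1.0 - one_hot) + new_k * one_hot
# Only the row at position is replaced; every other row keeps its cached value,
# which is what the old_k * inv_one_hot + k * one_hot lines compute.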