Example #1
    def get_indices(self, keys: mtf.Tensor,
                    query: mtf.Tensor) -> Tuple[mtf.Tensor, mtf.Tensor]:
        """Generate score and indices for the query."""
        score_shape = mtf.Shape(query.shape.dims[:-1] + keys.shape.dims[2:3])
        scores = mtf.einsum([query, keys],
                            output_shape=score_shape)  # [b, l, h, 2, n_keys]
        knn_dim = mtf.Dimension("knn", self.knn)
        scores, indices = mtf.top_k(scores, score_shape.dims[-1],
                                    knn_dim)  # [b, l, h, 2, knn]

        # Computes the top Cartesian products and their indices
        knn_square_dim = mtf.Dimension("knn_square_dim", self.knn**2)
        scores1, scores2 = mtf.unstack(scores, scores.shape.dims[-2])
        scores2 = mtf.rename_dimension(scores2, "knn", "knn2")
        out_shape = mtf.Shape(scores1.shape.dims + scores2.shape.dims[-1:])
        all_scores = mtf.add(scores1, scores2, output_shape=out_shape)
        all_scores = mtf.replace_dimensions(all_scores, out_shape[-2:],
                                            knn_square_dim)

        indices1, indices2 = mtf.unstack(indices, indices.shape.dims[-2])
        indices1 = mtf.multiply(indices1, self.n_keys)
        indices2 = mtf.rename_dimension(indices2, "knn", "knn2")
        all_indices = mtf.add(indices1, indices2, output_shape=out_shape)
        all_indices = mtf.replace_dimensions(all_indices, out_shape[-2:],
                                             knn_square_dim)

        scores, best_indices = mtf.top_k(all_scores, all_scores.shape.dims[-1],
                                         knn_dim)
        return scores, mtf.gather(all_indices, best_indices, knn_square_dim)
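A minimal NumPy sketch of the product-key trick above (the names n_keys and knn mirror the mtf code; the data is made up): rather than scoring all n_keys**2 composite keys, keep the top-knn scores per half-key, form their knn**2 pairwise sums, and re-rank those. The composite index of the pair (i, j) is i * n_keys + j, matching the multiply/add on the indices in the example.

import numpy as np

n_keys, knn = 8, 3
scores1 = np.random.randn(n_keys)  # scores against the first half-keys
scores2 = np.random.randn(n_keys)  # scores against the second half-keys

top1 = np.argsort(scores1)[-knn:]  # indices of the top-knn half-scores
top2 = np.argsort(scores2)[-knn:]

# Pairwise sums of the surviving half-scores: [knn, knn] -> [knn**2]
all_scores = (scores1[top1][:, None] + scores2[top2][None, :]).reshape(-1)
# Composite index of pair (i, j) in the flattened n_keys**2 key table
all_indices = (top1[:, None] * n_keys + top2[None, :]).reshape(-1)

best = np.argsort(all_scores)[-knn:]  # final top-knn over the products
print(all_scores[best], all_indices[best])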
Example #2
def rename_dimension(x, old_dim_name, new_dim_name):
    assert isinstance(x, mtf.Tensor)
    if old_dim_name == new_dim_name:
        return x

    if old_dim_name.startswith('axis') and new_dim_name.startswith('axis'):
        tmp_dim_name = utils.RandName()
        x = mtf.rename_dimension(x, old_dim_name, tmp_dim_name)
        old_dim_name = tmp_dim_name

    return mtf.rename_dimension(x, old_dim_name, new_dim_name)
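A hypothetical call of the wrapper above ("axis0"/"axis1" are made-up mesh-split dimension names; utils.RandName() is assumed to return a fresh, unused name). Routing an 'axis*'-to-'axis*' rename through a temporary name keeps the two mesh-axis names from colliding mid-rename.

y = rename_dimension(x, 'axis0', 'axis1')  # internally: axis0 -> tmp -> axis1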
Example #3
def linear_attention(q, k, v):
    batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3])
    q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in")
    k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in")

    dim_in = k.shape[-1]

    q = mtf.softmax(q, dim_in)
    k = mtf.softmax(k, seq_dim)

    context = mtf.einsum([k, v], output_shape=[batch_dim, head_dim, dim_in, dim_out])
    attn = mtf.einsum([q, context], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
    return attn
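The same contraction order in plain NumPy, as a sketch of the technique rather than the mtf API (shapes are illustrative): normalizing k over the sequence and q over the feature axis lets the context be aggregated once, in O(n), instead of forming an n-by-n attention matrix.

import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

b, n, h, d_in, d_out = 2, 16, 4, 8, 8
q = softmax(np.random.randn(b, n, h, d_in), axis=-1)  # over features
k = softmax(np.random.randn(b, n, h, d_in), axis=1)   # over the sequence
v = np.random.randn(b, n, h, d_out)

context = np.einsum('bnhi,bnho->bhio', k, v)     # [b, h, d_in, d_out]
attn = np.einsum('bnhi,bhio->bnho', q, context)  # [b, n, h, d_out]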
Example #4
    def forward(self, features, return_loss=True, return_logits=False):
        inputs = features["tokens"]
        tokens = self.positional_embedding(self.embedding(inputs, "embedding"),
                                           "positional_embedding")

        mask = self.get_attn_mask(tokens.mesh, tokens.shape[1],
                                  self.dimensions["memory_len_dim"])
        out = self.transformer(tokens, mask=mask)
        logits = self.to_logits(out)
        if not return_loss:
            return logits

        labels = pad(inputs, [0, 1],
                     dim_name="total_seq_dim",
                     pad_value=self.eos_token_id)
        indices = mtf.range(labels.mesh,
                            mtf.Dimension("range", labels.shape[1].size - 1),
                            tf.int32,
                            name="labels_indices") + 1
        labels = mtf.gather(labels, indices, dim=labels.shape[1])
        labels = mtf.rename_dimension(labels, "range", "total_seq_dim")
        loss, loss_batch = self._loss(logits, labels)
        if return_logits and return_loss:
            # Cast back to checkpoint dtype
            logits = mtf.cast(logits, self.variable_dtype.master_dtype)
            return loss, loss_batch, logits
        return loss, loss_batch
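A sketch of the label construction above with made-up values: pad the inputs with the EOS id, then gather positions 1..N so that labels[t] == inputs[t + 1].

import numpy as np

eos_token_id = 0
inputs = np.array([5, 6, 7, 8])
padded = np.concatenate([inputs, [eos_token_id]])  # pad(inputs, [0, 1])
labels = padded[np.arange(len(padded) - 1) + 1]    # -> [6, 7, 8, 0]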
Example #5
def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh):
    """memory / key values from all attention paper"""

    dim_mem_kv = mtf.Dimension("mem_kv_sequence", num_mem_kv)
    emb_dim = k.shape[-1]
    mem_std = 1 / math.sqrt(emb_dim.size)

    mem_k = mtf.get_variable(mesh, "mem_k", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
                             initializer=tf.random_normal_initializer(stddev=mem_std),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype,
                             )
    mem_v = mtf.get_variable(mesh, "mem_v", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
                             initializer=tf.random_normal_initializer(stddev=mem_std),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype)

    mem_k, mem_v = map(lambda t: mtf.broadcast(t, [dim_batch, dim_mem_kv, dim_heads, emb_dim]),
                       (mem_k, mem_v))
    mem_k, mem_v = map(lambda t: mtf.rename_dimension(t, "mem_kv_sequence", "sequence"),
                       (mem_k, mem_v))

    k = mtf.concat([mem_k, k], "sequence")
    v = mtf.concat([mem_v, v], "sequence")
    return k, v
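The effect of the function above, sketched in NumPy with illustrative shapes: num_mem_kv learned slots are broadcast over the batch and prepended along the sequence axis, so attention always sees them regardless of the input length.

import numpy as np

batch, seq, heads, emb, num_mem_kv = 2, 8, 4, 16, 3
k = np.random.randn(batch, seq, heads, emb)
mem_k = np.random.randn(num_mem_kv, heads, emb) / np.sqrt(emb)

mem_k = np.broadcast_to(mem_k, (batch, num_mem_kv, heads, emb))
k = np.concatenate([mem_k, k], axis=1)  # sequence grows to num_mem_kv + seq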
Example #6
def reshape(x, new_shape):
    old_shape = x.shape
    assert len(old_shape) == len(new_shape)
    for o, n in zip(old_shape.dims, new_shape.dims):
        if (o.name != n.name) and (o.name.startswith('axis')
                                   and n.name.startswith('axis')):
            x = mtf.rename_dimension(x, o.name, utils.RandName())
    return mtf.reshape(x, new_shape)
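Hypothetical usage of the wrapper above (the dimension names and sizes are made up): any slot where both the old and the new name start with 'axis' is first renamed to a random fresh name, so mtf.reshape never has to swap one mesh-axis name directly for another.

# x is assumed to have shape [('axis0', 4), ('rows', 8)]
new_shape = mtf.Shape([mtf.Dimension('axis1', 4), mtf.Dimension('rows', 8)])
y = reshape(x, new_shape)  # 'axis0' -> random tmp name -> reshape to 'axis1'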
Example #7
def causal_linear_attention(q, k, v, epsilon=1e-6):
    batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3])
    q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in")
    k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in")

    dim_in = k.shape[-1]

    q = mtf.softmax(q, dim_in)
    k = mtf.exp(k)

    cumulative_k = mtf.cumsum(k, seq_dim)
    context = mtf.einsum([k, v], output_shape=[batch_dim, seq_dim, head_dim, dim_in, dim_out])
    cumulative_context = mtf.cumsum(context, seq_dim)

    cumulative_context /= (cumulative_k + epsilon)
    attn = mtf.einsum([q, cumulative_context], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
    return attn
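A NumPy sketch of the causal variant (illustrative shapes, not the mtf API): prefix sums over the sequence give every position a context built only from itself and earlier positions, normalized by the running key sum plus epsilon.

import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

b, n, h, d_in, d_out, eps = 1, 8, 2, 4, 4, 1e-6
q = softmax(np.random.randn(b, n, h, d_in), axis=-1)
k = np.exp(np.random.randn(b, n, h, d_in))
v = np.random.randn(b, n, h, d_out)

cum_k = np.cumsum(k, axis=1)                          # [b, n, h, d_in]
context = np.einsum('bnhi,bnho->bnhio', k, v)         # per-step outer products
cum_context = np.cumsum(context, axis=1) / (cum_k[..., None] + eps)
attn = np.einsum('bnhi,bnhio->bnho', q, cum_context)  # [b, n, h, d_out]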
Example #8
    def hidden_to_logits(self, hidden: mtf.Tensor,
                         context: transformer.Context) -> mtf.Tensor:
        """Function called by mtf transformer to get the logits.

    Args:
      hidden: an mtf.Tensor, hidden model states of the final decoder layer.
      context: a transformer.Context, the context used for the call to the
        transformer.

    Returns:
      An mtf.Tensor, the logits.
    """
        hidden *= self._output_dim.size**-0.5

        component_contexts = mtf.einsum([
            mtf.rename_dimension(hidden, self._output_dim.name,
                                 self._copy_output_dim.name),
            self._context_weights,
        ],
                                        reduced_dims=[self._copy_output_dim])
        component_contexts = mtf.tanh(component_contexts +
                                      self._context_weights_bias)
        component_logits = mtf.einsum(
            [component_contexts, self._embedding_weights],
            reduced_dims=[self._output_dim])
        component_logits = self._dropout(component_logits, context)

        prior_tanh = mtf.tanh(
            mtf.einsum([self._prior_weights, hidden],
                       reduced_dims=[self._output_dim]) +
            self._prior_weights_bias)
        prior_tanh = self._dropout(prior_tanh, context)
        prior_shared_logits = mtf.einsum([self._prior_gates_vector, hidden],
                                         reduced_dims=[self._output_dim])
        prior_frequent_vocab_logits = (
            mtf.einsum([self._prior_vocab_vector, prior_tanh]) +
            prior_shared_logits + self._prior_bias)
        prior_logits = mtf.concat([
            prior_frequent_vocab_logits,
            mtf.ones(self._mesh,
                     mtf.Shape([self._rare_vocab_dim]),
                     dtype=prior_shared_logits.dtype) * prior_shared_logits
        ], self._vocab_dim.name)
        if context.train and self._noise_std_dev != 0.0:
            prior_logits += mtf.random_normal(self._mesh,
                                              prior_logits.shape,
                                              stddev=self._noise_std_dev)
        prior_proportions = self._sigmoid_tree(prior_logits)

        logits = mtf.einsum([component_logits, prior_proportions],
                            reduced_dims=[self._gates_dim])
        return self._rearrange_sentinels(logits)
Example #9
    def call(self, context, x, losses=None):
        """Call the layer."""

        if self.canine_mode:
            # This is the CANINE-like ByT5 + LASC baseline in the paper.
            return self.call_canine_encoder(context, x, losses=losses)

        if self.conv_type:
            if self.conv_type == "conv1d":
                tf.logging.info("Using 1d conv")
                tmp_output = mtf.Dimension("tmp_dim", x.shape[-1].size)
                orig_dim = x.shape[-1]
                x = mtf.layers.conv1d(x,
                                      tmp_output,
                                      filter_size=self.filter_size,
                                      stride=1)
                x = mtf.rename_dimension(x, "tmp_dim", orig_dim.name)
                tf.logging.info(x)
        if self.norm:
            x = sublayer_rms_norm(x, None, context)
        o = x
        olength = o.shape.get_dim_by_name("length")
        o = custom_attention.gradient_based_subword_tokenization(
            o,
            olength,
            downsample=self.downsample_query,
            use_offsets=self.use_offsets,
            consider_chars_as_blocks=self.consider_chars_as_blocks,
            use_block_pos_embedding=self.use_block_pos_embedding,
            memory_embeddings=self.num_memory_slots,
            context=context,
            block_mixing_mode=self.block_mixing_mode,
            activation=self.rank_activation,
            downsample_function=self.gbst_pool)
        new_length_dim = o.shape.get_dim_by_name("length")
        context.length_dim = new_length_dim
        new_context_position = context.get_position()
        context.position = new_context_position
        context.sequence_id = mtf.slice(context.sequence_id,
                                        begin=0,
                                        size=new_length_dim.size,
                                        slice_dim_name=new_length_dim.name)
        if self.use_ffn:
            # not actually used in Charformer.
            tf.logging.info("Using FFN")
            o2 = self.ffn.call(context, o)
            o = o + o2
        if self.norm:
            o = sublayer_rms_norm(o, None, context)
        olength = o.shape.get_dim_by_name("length")
        return o, context
Example #10
    def call(self, context, x, losses=None):
        """Call the layer."""
        wq, wk, wv, wo = mtf.layers.multihead_attention_params(
            context.mesh, self.heads_dim, context.model_dim, self.kv_dim,
            context.variable_dtype)
        memory_length = mtf.Dimension("memory_length", context.length_dim.size)
        q = mtf.einsum([x, wq], reduced_dims=[context.model_dim])
        if context.mode == "incremental":
            m = x
        else:
            m = mtf.rename_dimension(x, context.length_dim.name,
                                     "memory_length")
        k = mtf.einsum([m, wk], reduced_dims=[context.model_dim])
        v = mtf.einsum([m, wv], reduced_dims=[context.model_dim])
        if context.mode == "incremental":
            old_k, old_v = context.get_states(2)
            one_hot = mtf.one_hot(context.position,
                                  memory_length,
                                  dtype=context.activation_dtype)
            inv_one_hot = 1.0 - one_hot
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot
        if context.mode == "incremental" or context.mode == "first_part":
            context.record_new_states([k, v])
        masks = []
        if context.autoregressive:
            masks.append(
                mtf.cast(
                    mtf.less(
                        context.position,
                        mtf.range(context.mesh, memory_length,
                                  dtype=tf.int32)), context.activation_dtype) *
                -1e9)
        if (context.sequence_id is not None
                and isinstance(context.sequence_id, mtf.Tensor)
                and context.length_dim in context.sequence_id.shape):
            masks.append(
                mtf.cast(
                    mtf.not_equal(
                        context.sequence_id,
                        mtf.layers.rename_length_to_memory_length(
                            context.sequence_id)), context.activation_dtype) *
                -1e9)
        mask = mtf.add_n(masks) if masks else None

        o = mtf.layers.dot_product_attention_v2(
            q, k, v, memory_length, self.kv_dim, self.kv_dim, mask,
            self.dropout_rate if context.train else 0.0, [context.length_dim])
        return mtf.einsum([o, wo],
                          x.shape,
                          reduced_dims=[self.heads_dim, self.kv_dim])
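The incremental cache update above, sketched in NumPy (shapes are illustrative): the current step's key is written into slot pos of the cached keys by blending with a one-hot mask, which is how a scatter is expressed here.

import numpy as np

mem_len, d = 8, 4
old_k = np.zeros((mem_len, d))  # cached keys from previous steps
new_k = np.random.randn(d)      # key computed for the current step
pos = 3                         # current decode position

one_hot = np.eye(mem_len)[pos][:, None]  # [mem_len, 1]
k = old_k * (1.0 - one_hot) + new_k[None, :] * one_hot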
Example #11
 def call_canine_encoder(self, context, x, losses=None):
     """Call Canine baseline encoder (Byte level T5 + LASC in paper)."""
     # local attention
     params = self.make_params(context)
     q = params.compute_q(x)
     if self.shared_kv:
         kv = params.compute_kv(x)
         k = kv
         v = kv
     else:
         k = params.compute_k(x)
         v = params.compute_v(x)
     # local attention
     output_shape = x.shape
     x = custom_attention.local_attention_1d(
         q,
         k,
         v,
         length_dim=context.length_dim,
         length_dim_num_splits=1,
         key_dim=self.kv_dim,
         value_dim=self.kv_dim,
         fully_autoregressive=False,
         radius=self.radius,
         sequence_id=context.sequence_id,
         write_priority=context.write_priority,
         read_priority=context.read_priority,
         context=context,
         attention_kwargs=self.attention_kwargs_from_context(context))
     o = params.compute_output(x, output_shape=output_shape)
     # strided convolutions
     tmp_output = mtf.Dimension("tmp_dim", o.shape[-1].size)
      # the downsample_query arg is reused here as "r"
     o = mtf.layers.conv1d(o,
                           tmp_output,
                           filter_size=self.filter_size,
                           stride=int(self.downsample_query))
     o = mtf.rename_dimension(o, "tmp_dim", "d_model")
     tf.logging.info(o)
     new_length_dim = o.shape.get_dim_by_name("length")
     context.length_dim = new_length_dim
     new_context_position = context.get_position()
     context.position = new_context_position
     context.sequence_id = mtf.slice(context.sequence_id,
                                     begin=0,
                                     size=new_length_dim.size,
                                     slice_dim_name=new_length_dim.name)
     return o, context
Example #12
    def hidden_to_logits(self, hidden: mtf.Tensor,
                         context: transformer.Context) -> mtf.Tensor:
        """Function called by mtf transformer to get the logits.

    Note that we are taking the log of a mixture of softmaxes. The logits will
    then go through a softmax. This could potentially run into numerical
    stability issues. If that happens, try setting the activation_dtype to
    float32.

    Args:
      hidden: hidden model states of the final decoder layer.
      context: the context used for the call to the
        transformer.

    Returns:
      The logits.
    """
        del context
        hidden *= self._output_dim.size**-0.5

        component_prior_logits = mtf.einsum([hidden, self._mixture_weights],
                                            reduced_dims=[self._output_dim])

        component_contexts = mtf.einsum([
            mtf.rename_dimension(hidden, self._output_dim.name,
                                 self._copy_output_dim.name),
            self._context_weights,
        ],
                                        reduced_dims=[self._copy_output_dim])
        component_contexts = mtf.tanh(component_contexts)
        component_logits = mtf.einsum(
            [component_contexts, self._embedding_weights],
            reduced_dims=[self._output_dim])

        component_prior_logits = mtf.log_softmax(
            component_prior_logits, reduced_dim=self._components_dim)
        component_logits = mtf.log_softmax(component_logits,
                                           reduced_dim=self._vocab_dim)

        logits = component_prior_logits + component_logits
        logits = mtf.reduce_logsumexp(logits, reduced_dim=self._components_dim)
        return logits
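What the log-space arithmetic above computes, as a small NumPy check with made-up sizes: the log of a mixture of softmaxes is a logsumexp over components of log-prior plus per-component log-softmax, and the result still normalizes over the vocabulary.

import numpy as np

def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

components, vocab = 3, 10
prior_logits = np.random.randn(components)
component_logits = np.random.randn(components, vocab)

log_mix = log_softmax(prior_logits)[:, None] + log_softmax(component_logits)
logits = np.log(np.exp(log_mix).sum(axis=0))  # reduce_logsumexp over components
assert np.isclose(np.exp(logits).sum(), 1.0)  # still a distribution over vocab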
Example #13
def attention(x, dim_head, dim_features_head, scope='attn', causal=False):
    with tf.variable_scope(scope):
        mesh, batch, seq, dim = x.mesh, *x.shape

        dim_heads = mtf.Dimension('dim_heads',
                                  dim_head.size * dim_features_head.size)
        dim_intermediate = mtf.Dimension('qkv_dimension', dim_heads.size * 3)
        qkv = linear(x, dim_intermediate, bias=False, scope='to_qkv')

        q, k, v = mtf.split(qkv, dim_intermediate, 3)
        q, k, v = map(
            lambda t: mtf.reshape(t, [batch, seq, dim_head, dim_features_head]
                                  ), (q, k, v))
        q, k, v = map(
            lambda t: mtf.transpose(
                t, [batch, dim_head, seq, dim_features_head]), (q, k, v))

        k, v = map(
            lambda t: mtf.rename_dimension(t, seq.name, 'memory_length'),
            (k, v))
        mem_len_dim = v.shape[-2]

        dots = mtf.layers.us_einsum([q, k],
                                    [batch, dim_head, seq, mem_len_dim])

        if causal:
            i = mtf.range(mesh, seq, tf.int32)
            j = mtf.range(mesh, mem_len_dim, tf.int32)
            i, j = map(lambda t: mtf.broadcast(t, [seq, mem_len_dim]), (i, j))
            mask = mtf.less(i + mem_len_dim.size - seq.size, j)
            mask = mtf.cast(mask, tf.float32) * -1e10
            dots += mask

        attn = mtf.softmax(dots, mem_len_dim)
        out = mtf.einsum([attn, v], [batch, dim_head, seq, dim_features_head])

        out = mtf.transpose(out, [batch, seq, dim_head, dim_features_head])
        out = mtf.reshape(out, [batch, seq, dim_heads])

        combined_out = linear(out, dim, scope='combine_output')
        return combined_out
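The causal mask built above, restated in NumPy (sizes are illustrative): query position i, offset so that the last query lines up with the last memory slot, is blocked from every memory position j ahead of it.

import numpy as np

seq, mem_len = 4, 6
i = np.arange(seq)[:, None]
j = np.arange(mem_len)[None, :]
mask = (i + mem_len - seq < j).astype(np.float32) * -1e10  # [seq, mem_len]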
Example #14
    def _call_internal(self,
                       context,
                       inputs,
                       targets=None,
                       attributes=None,
                       z=None):
        """Compute logits based on inputs (all positions in parallel).
        Also updates context if applicable.
        Args:
          context: a Context
          inputs: a Tensor
          targets: an optional Tensor
          attributes: an optional Tensor
          z: an optional Tensor
        Returns:
          logits: a Tensor with shape [<batch_dims>, length_dim, output_vocab_dim]
        """
        mesh = inputs.mesh
        if self.ensemble_dim and self.ensemble_dim not in inputs.shape.dims:
            # Training an ensemble where all models are trained on the same examples.
            inputs = mtf.broadcast(inputs,
                                   [self.ensemble_dim] + inputs.shape.dims)
            if self.ensemble_dim not in attributes.shape.dims:
                attributes = mtf.broadcast(attributes, [self.ensemble_dim] +
                                           attributes.shape.dims)
            if targets:
                targets = mtf.broadcast(targets, [self.ensemble_dim] +
                                        targets.shape.dims)
        if "embedding" in context.shared_params:
            vocab_embedding = context.shared_params["embedding"]
        else:
            vocab_embedding = VocabEmbedding(mesh,
                                             self.input_vocab_dim,
                                             self.model_dim,
                                             context.variable_dtype,
                                             name="embedding",
                                             ensemble_dim=self.ensemble_dim)
        x = vocab_embedding.ids_to_embedding(inputs)
        if self.positional_embedding:
            if "positional_embedding" in context.shared_params:
                pos_emb_var = context.shared_params["positional_embedding"]
            else:
                pos_emb_var = mtf.layers.embedding_weights(
                    mesh,
                    self.max_length_dim,
                    self.model_dim,
                    context.variable_dtype,
                    "positional_embedding",
                    ensemble_dim=self.ensemble_dim)
            if (context.length_dim is not None
                    and context.length_dim.size > self.max_length_dim.size):
                message = (
                    "Length dimenison exceeds size of positional embedding table. "
                    "length_dim.size > max_length_dim.size %s vs %s." %
                    (context.length_dim, self.max_length_dim))
                if context.position_is_default:
                    # Definitely getting overflow in this case.
                    raise ValueError(message)
                else:
                    tf.logging.warning(
                        message +
                        " This may be OK if there are several shorter sequences packed "
                        "together.  Otherwise, the later positions will get zeros."
                    )
            if context.position_is_default:
                pos_emb = mtf.rename_dimension(
                    mtf.slice(pos_emb_var, 0, context.length_dim.size,
                              self.max_length_dim.name),
                    self.max_length_dim.name, context.length_dim.name)
            else:
                pos_emb = mtf.gather(pos_emb_var,
                                     context.position,
                                     self.max_length_dim,
                                     output_shape=x.shape)
            x += pos_emb

        if self.attribute_embedding:
            if "attribute_embedding" in context.shared_params:
                att_emb_var = context.shared_params["attribute_embedding"]
            else:
                att_emb_var = mtf.layers.embedding_weights(
                    mesh,
                    self.attribute_dim,
                    self.model_dim,
                    context.variable_dtype,
                    "attribute_embedding",
                    ensemble_dim=self.ensemble_dim)

            att_emb = mtf.gather(att_emb_var,
                                 attributes,
                                 self.attribute_dim,
                                 output_shape=x.shape)
            # Addition of x and attribute
            # x *= LAMBDA_ATTRIBUTE * sty_emb #

            # Concatenation of x and attribute
            x_attribute = mtf.concat([x, att_emb], self.model_dim.name)
            x = mtf.layers.dense(x_attribute,
                                 self.model_dim,
                                 activation=None,
                                 variable_dtype=context.variable_dtype,
                                 name="comb_x_attribute")

        if z:
            z = mtf.layers.dense(z,
                                 self.model_dim,
                                 activation=None,
                                 variable_dtype=context.variable_dtype,
                                 name="z")
            # raise ValueError("x shape=%s , z shape=%s" % (x.shape, z.shape))
            x += z

        x = self.layer_stack.call(context, x)
        if self.output_vocab_dim is None:
            return x
        if self.shared_embedding_and_softmax_weights:
            logits = vocab_embedding.hidden_to_logits(x)
        else:
            logits = mtf.layers.dense(x,
                                      self.output_vocab_dim,
                                      use_bias=False,
                                      variable_dtype=context.variable_dtype,
                                      reduced_dims=x.shape.dims[-1:],
                                      name="logits")
        if targets is not None and context.losses is not None:
            context.losses.append(
                self._compute_loss(context, logits, targets,
                                   self.output_vocab_dim))
        if self.ensemble_dim:
            logits = reduce_ensemble_logits(logits, self.ensemble_dim,
                                            self.output_vocab_dim)
        return logits
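A sketch of the default-position branch above (illustrative sizes): when positions are simply 0..length-1, gathering the positional embedding reduces to slicing the first length rows of the table; the mtf code then renames max_length_dim to length_dim so the addition lines up by dimension name.

import numpy as np

max_length, length, d_model = 512, 128, 16
pos_emb_var = np.random.randn(max_length, d_model)
pos_emb = pos_emb_var[:length]  # stands in for mtf.slice + rename_dimension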
Example #15
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
  """Bottleneck block variant for residual networks with BN after convolutions.

  Args:
    inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters for the first two convolutions. Note
        that the third and final convolution will use 4 times as many filters.
    is_training: `bool` for whether the model is in training mode.
    strides: `int` block stride. If greater than 1, this block will ultimately
        downsample the input.
    projection_shortcut: `function` to use for projection shortcuts (typically
        a 1x1 convolution to match the filter dimensions). If None, no
        projection is used and the input is passed as unchanged through the
        shortcut connection.
    row_blocks_dim: a mtf.Dimension, row dimension which is
        spatially partitioned along mesh axis
    col_blocks_dim: a mtf.Dimension, column dimension which is
        spatially partitioned along mesh axis

  Returns:
    The output `Tensor` of the block.
  """
  shortcut = inputs

  filter_h_dim = mtf.Dimension("filter_height", 3)
  filter_w_dim = mtf.Dimension("filter_width", 3)
  one_h_dim = mtf.Dimension("filter_height", 1)
  one_w_dim = mtf.Dimension("filter_width", 1)

  if projection_shortcut is not None:
    filters_dim = mtf.Dimension("filtersp", filters)
    kernel = mtf.get_variable(
        inputs.mesh, "kernel", mtf.Shape(
            [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters_dim]))
    shortcut = projection_shortcut(inputs, kernel)

  # First conv block
  filters1_dim = mtf.Dimension("filters1", filters)
  kernel1 = mtf.get_variable(
      inputs.mesh, "kernel1", mtf.Shape(
          [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters1_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs,
      kernel1,
      strides=[1, 1, 1, 1],
      padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

  # TODO(nikip): Add Dropout?
  inputs = batch_norm_relu(inputs, is_training)

  # Second conv block
  filters2_dim = mtf.Dimension("filters2", 4*filters)
  kernel2 = mtf.get_variable(
      inputs.mesh, "kernel2", mtf.Shape(
          [filter_h_dim, filter_w_dim, filters1_dim, filters2_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs,
      kernel2,
      strides=[1, 1, 1, 1],
      padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)

  inputs = batch_norm_relu(inputs, is_training)

  # Third wide conv filter block
  filters3_dim = mtf.Dimension("filters3", filters)
  filters3_kernel = mtf.get_variable(
      inputs.mesh, "wide_kernel", mtf.Shape(
          [one_h_dim, one_w_dim, filters2_dim, filters3_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs,
      filters3_kernel,
      strides,
      padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

  # TODO(nikip): Although the original resnet code has this batch norm, in our
  # setup this is causing no gradients to be passed. Investigate further.
  # inputs = batch_norm_relu(inputs, is_training, relu=True)

  # TODO(nikip): Maybe add residual with a projection?
  return mtf.relu(
      shortcut + mtf.rename_dimension(
          inputs, inputs.shape.dims[-1].name, shortcut.shape.dims[-1].name))
Example #16
  def _sample(self, features, mesh):
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "encdec":
      inputs = features["inputs"]
      while len(inputs.shape.as_list()) > 2:
        inputs = tf.squeeze(inputs, axis=2)
      actual_batch_size = tf.shape(inputs)[0]
      actual_length = tf.shape(inputs)[1]
      inputs = tf.pad(
          inputs, [[0, hparams.batch_size - actual_batch_size],
                   [0, hparams.max_length - actual_length]])
      inputs = self._import_to_batch_by_length(
          inputs, "inputs", mesh, hparams)
      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.reshape(positional_embedding_var,
                       mtf.Shape([self.length_dim, self.model_dim])))
      encoder_attention_mask = (
          mtf.layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_attention_mask)
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
      encdec_tensors = []
      for layer_num, layer_type in enumerate(hparams.decoder_layers):
        if layer_type == "enc_att":
          with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
            q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
                mesh, self.heads_dim, self.model_dim,
                self.kv_dim, self.master_dtype, self.slice_dtype,
                self.activation_dtype)
            k = mtf.einsum(
                [encoder_output, k_var],
                mtf.Shape(
                    self.batch_dims + [self.heads_dim,
                                       self.memory_length_dim, self.kv_dim]))
            v = mtf.einsum(
                [encoder_output, v_var],
                mtf.Shape(
                    self.batch_dims + [self.heads_dim,
                                       self.memory_length_dim, self.kv_dim]))
          encdec_tensors.append((q_var, o_var, k, v))
        else:
          encdec_tensors.append(None)
      partial_targets = None
    elif hparams.transformer_type == "decoder":
      encdec_tensors = None
      encoder_output = None
      encoder_attention_mask = None
      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs", None)
      if partial_targets is None:
        partial_targets = features.get("targets", None)
      if partial_targets is not None:
        partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
        partial_targets = tf.to_int32(partial_targets)
        partial_targets_batch = tf.shape(partial_targets)[0]
        partial_targets_length = tf.shape(partial_targets)[1]
        partial_targets = tf.pad(
            partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                              [0, hparams.max_length - partial_targets_length]])
        partial_targets = self._import_to_batch_by_length(
            partial_targets, "partial_targets", mesh, hparams)
    else:
      raise ValueError(
          "hparams.model_type = %s not yet supported"
          % hparams.transformer_type)

    local_attention_window = mtf.Dimension(
        "local_attention_window", hparams.local_attention_window_size)
    if hparams.beam_size == 1:
      ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
      kv_shape = mtf.Shape(self.batch_dims +
                           [self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
      local_kv_shape = mtf.Shape(self.batch_dims +
                                 [self.heads_dim,
                                  local_attention_window, self.kv_dim])
    else:
      beam_dim = mtf.Dimension("beam", hparams.beam_size)
      ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
      kv_shape = mtf.Shape(self.batch_dims +
                           [beam_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
      local_kv_shape = mtf.Shape(self.batch_dims +
                                 [beam_dim, self.heads_dim,
                                  local_attention_window, self.kv_dim])

    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    initial_states = []
    for layer in hparams.decoder_layers:
      if layer == "att":
        initial_states.extend(
            [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
      elif layer == "local_att":
        initial_states.extend(
            [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

    def logits_fn(step_num, ids, states):
      """Produce logits for this step, and new states."""
      ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
      x = (mtf.gather(targets_embedding_var, ids_this_step,
                      self.targets_vocab_dim) +
           mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
      with tf.variable_scope("decoder"):
        x, new_states = self._layer_stack(
            x,
            hparams.decoder_layers,
            encdec_attention_mask=encoder_attention_mask,
            step_num=step_num,
            encdec_tensors=encdec_tensors,
            states=states)
      logits = mtf.matmul(x, softmax_var)
      return logits, new_states

    if hparams.beam_size == 1:
      temperature = (0.0 if hparams.sampling_method == "argmax"
                     else hparams.sampling_temp)
      return mtf.beam_search.greedy_decode(
          logits_fn,
          initial_ids,
          temperature=temperature,
          initial_states=initial_states,
          forced_ids=partial_targets,
          use_tpu=hparams.use_tpu)
    else:
      if hparams.transformer_type == "encdec":
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=self.length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * hparams.decode_length_multiplier
            + hparams.decode_length_constant, tf.int32)
      else:
        decode_length = None
      beams, unused_scores = mtf.beam_search.beam_search(
          logits_fn,
          initial_ids,
          hparams.alpha,
          states=initial_states,
          decode_length=decode_length,
          use_tpu=hparams.use_tpu,
          dtype=self.activation_dtype)
      return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
Example #17
  def __init__(self,
               config,
               is_training,
               input_ids,
               input_mask=None,
               token_type_ids=None,
               scope=None,
               mesh_shape="",
               layout=""):
    self.config = copy.deepcopy(config)
    del config
    if not is_training:
      self.config.layer_output_dropout_prob = 0.0
      self.config.attention_probs_dropout_prob = 0.0
      self.config.feedforward_intermediate_dropout_prob = 0.0
    input_shape = input_ids.shape
    assert input_shape.ndims == 2

    self._seq_dim = input_shape.dims[1]
    self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size)
    self._extra_losses = []
    mesh = input_ids.mesh

    if token_type_ids is None:
      token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
      with tf.variable_scope("embeddings"):
        # Perform embedding lookup on the word ids.
        self.embedding_table = mtf.get_variable(
            mesh, "word_embeddings",
            mtf.Shape([self.vocab_dim, self.model_dim]),
            initializer=self.embedding_initializer)
        self.word_embedding_output = mtf.gather(
            self.embedding_table, input_ids, self.vocab_dim)

        # Add positional embeddings and token type embeddings, then layer
        # normalize and perform dropout.
        self.embedding_output = self.word_embedding_output

        token_type_table = mtf.get_variable(
            mesh, "token_type_embeddings",
            mtf.Shape([self.token_type_vocab_dim, self.model_dim]),
            initializer=self.embedding_initializer)
        if token_type_ids is not None:
          self.embedding_output += mtf.gather(
              token_type_table, token_type_ids, self.token_type_vocab_dim)
        if self.config.position_signal == "embedding":
          full_position_table = mtf.get_variable(
              mesh, "position_embeddings",
              mtf.Shape([self.max_position_embeddings_dim, self.model_dim]),
              initializer=self.embedding_initializer)
          short_position_table = mtf.rename_dimension(
              mtf.slice(full_position_table, 0, self.seq_dim.size,
                        self.max_position_embeddings_dim.name),
              self.max_position_embeddings_dim.name, self.seq_dim.name)
          self.embedding_output += short_position_table
        self.embedding_output = self.normalize(self.embedding_output)
        self.embedding_output = mtf.dropout(
            self.embedding_output, is_training,
            keep_prob=1.0 - self.config.layer_output_dropout_prob)

      with tf.variable_scope("encoder"):
        attention_biases = []
        if input_mask:
          # [batch_dim, memory_seq_dim]
          attention_biases.append(
              (1.0 - mtf.to_float(mtf.replace_dimensions(
                  input_mask, self.seq_dim, self.memory_seq_dim))) * -10000.0)
        if self.config.position_signal == "relative_attention_bias":
          buckets_dim = mtf.Dimension("buckets", 32)
          rp_bucket = _relative_position_bucket(
              mtf.range(mesh, self.memory_seq_dim, tf.int32)
              - mtf.range(mesh, self.seq_dim, tf.int32),
              num_buckets=buckets_dim.size)
          bias_var = mtf.get_variable(
              mesh, "relative_attention_bias",
              [self.num_heads_dim, buckets_dim],
              initializer=tf.zeros_initializer())
          attention_biases.append(mtf.gather(bias_var, rp_bucket, buckets_dim))
        attention_bias = mtf.add_n(attention_biases)
        prev_layer_output = self.embedding_output
        self.all_encoder_layers = []
        for block_num in range(self.config.num_blocks):
          with tf.variable_scope("block_%d" % block_num):
            for layer_idx, layer_type in enumerate(self.config.block_layers):
              layer_name = layer_type
              count = self.config.block_layers[:layer_idx].count(layer_type)
              if count:
                layer_name += "_%d" % count
              with tf.variable_scope(layer_name):
                x = prev_layer_output
                if self.config.residual_structure == "direct":
                  x = self.normalize(x)
                if layer_type == "attention":
                  x = self.self_attention(x, attention_bias)
                elif layer_type == "feedforward":
                  x = self.feedforward(x)
                elif layer_type == "moe":
                  x = self.moe(x, layout, mesh_shape, input_mask, is_training)
                else:
                  raise ValueError("unknown layer type " + layer_type)
                x = mtf.dropout(
                    x, is_training,
                    keep_prob=1.0 - self.config.layer_output_dropout_prob)
                layer_output = prev_layer_output + x
                if self.config.residual_structure == "original":
                  layer_output = self.normalize(layer_output)
                prev_layer_output = layer_output
          self.all_encoder_layers.append(layer_output)

      self.sequence_output = prev_layer_output
      if self.config.residual_structure == "direct":
        self.sequence_output = self.normalize(self.sequence_output)

      # The "pooler" converts the encoded sequence tensor of shape
      # [batch_dim, seq_dim, hidden_size] to a tensor of shape
      # [batch_dim, hidden_size]. This is necessary for segment-level
      # (or segment-pair-level) classification tasks where we need a fixed
      # dimensional representation of the segment.
      with tf.variable_scope("pooler"):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. We assume that this has been pre-trained.
        first_token_tensor = mtf.gather(self.sequence_output, 0, self.seq_dim)
        self.pooled_output = mtf.layers.dense(
            first_token_tensor,
            reduced_dims=[self.model_dim],
            new_dims=[self.model_dim],
            activation=mtf.tanh,
            kernel_initializer=self.dense_initializer,
            use_bias=self.config.use_bias)
Example #18
    def _mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        hparams = self._hparams
        targets = tf.to_int32(features["targets"])
        if len(targets.get_shape()) > 2:
            tf.logging.info("targets = %s" % targets)
            targets = tf.squeeze(targets, [2, 3])
        # pad targets to max_length
        def pad_to_max_length(x):
            extra_length = hparams.max_length - tf.shape(x)[1]
            x = tf.pad(x, [[0, 0], [0, extra_length]])
            x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
            return x

        targets = pad_to_max_length(targets)
        for key in [
                "targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"
        ]:
            if key in features:
                features[key] = pad_to_max_length(features[key])
        shifted_targets = common_layers.shift_right_2d(targets)

        targets = self._import_to_batch_by_length(targets, "targets", mesh,
                                                  hparams)
        shifted_targets = self._import_to_batch_by_length(
            shifted_targets, "shifted_targets", mesh, hparams)

        if "targets_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            targets_segmentation = self._import_to_batch_by_length(
                features["targets_segmentation"], "targets_segmentation", mesh,
                hparams)
            targets_position = self._import_to_batch_by_length(
                features["targets_position"], "targets_position", mesh,
                hparams)
            decoder_self_attention_mask = (
                mtf.layers.attention_mask_autoregressive(
                    targets_position, dtype=self.activation_dtype) +
                mtf.layers.attention_mask_same_segment(
                    targets_segmentation, dtype=self.activation_dtype))
        else:
            targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
            decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype)

        def layer_prepostprocess_dropout(x):
            return mtf.dropout(
                x,
                keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
                noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

        extra_losses = []
        (inputs_embedding_var, targets_embedding_var, softmax_var,
         positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
        if hparams.transformer_type == "decoder":
            encoder_output = None
            encoder_decoder_attention_mask = None
        else:
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = pad_to_max_length(inputs)
            inputs = self._import_to_batch_by_length(inputs, "inputs", mesh,
                                                     hparams)
            if "inputs_segmentation" in features:
                # "Packed" dataset - keep the examples from seeing each other.
                inputs_segmentation = self._import_to_batch_by_length(
                    features["inputs_segmentation"], "inputs_segmentation",
                    mesh, hparams)
                inputs_position = self._import_to_batch_by_length(
                    features["inputs_position"], "inputs_position", mesh,
                    hparams)
                encoder_self_attention_mask = (
                    mtf.layers.attention_mask_same_segment(
                        inputs_segmentation, dtype=self.activation_dtype))
            else:
                inputs_position = mtf.range(mesh,
                                            self.length_dim,
                                            dtype=tf.int32)
                encoder_self_attention_mask = (
                    mtf.layers.attention_mask_ignore_padding(
                        inputs, dtype=self.activation_dtype))

            x = (mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim) +
                 mtf.gather(positional_embedding_var, inputs_position,
                            self.max_length_dim))
            x = layer_prepostprocess_dropout(x)
            with tf.variable_scope("encoder"):
                x = self._layer_stack(
                    x,
                    hparams.encoder_layers,
                    self_attention_mask=encoder_self_attention_mask,
                    losses=extra_losses)

        if hparams.transformer_type == "encdec":
            if "inputs_segmentation" in features:
                encoder_decoder_attention_mask = (
                    mtf.layers.attention_mask_same_segment(
                        targets_segmentation,
                        inputs_segmentation,
                        dtype=self.activation_dtype))
            else:
                encoder_decoder_attention_mask = encoder_self_attention_mask
            encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                                  self.memory_length_dim.name)

        if hparams.transformer_type != "encoder":
            # DECODER
            x = (mtf.gather(targets_embedding_var, shifted_targets,
                            self.targets_vocab_dim) +
                 mtf.gather(positional_embedding_var, targets_position,
                            self.max_length_dim))
            x = layer_prepostprocess_dropout(x)
            with tf.variable_scope("decoder"):
                x = self._layer_stack(
                    x,
                    hparams.decoder_layers,
                    encoder_output=encoder_output,
                    self_attention_mask=decoder_self_attention_mask,
                    encdec_attention_mask=encoder_decoder_attention_mask,
                    losses=extra_losses)
        logits = mtf.matmul(x, softmax_var)
        if hparams.mode == tf.estimator.ModeKeys.TRAIN:
            logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
        off_value = hparams.label_smoothing / self._targets_vocab_size
        on_value = 1.0 - hparams.label_smoothing + off_value
        soft_targets = mtf.one_hot(targets,
                                   self.targets_vocab_dim,
                                   on_value=on_value,
                                   off_value=off_value,
                                   dtype=self.activation_dtype)
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, self.targets_vocab_dim)
        weights = mtf.layers.weights_nonzero(targets,
                                             dtype=self.activation_dtype)
        loss = mtf.reduce_mean(loss * weights)
        for l in extra_losses:
            loss += l
        logits = mtf.to_float(logits)
        # combine batch dims
        if len(self.batch_dims) > 1:
            combined_batch_dim = mtf.Dimension(self.batch_dims[0].name,
                                               mtf.Shape(self.batch_dims).size)
            logits = mtf.reshape(logits,
                                 [combined_batch_dim] + logits.shape.dims[-2:])
        return logits, loss
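The label-smoothing targets built above, in NumPy with made-up sizes: off_value is spread over the whole vocabulary and the gold token gets on_value, so every row still sums to 1.

import numpy as np

vocab, smoothing = 5, 0.1
off_value = smoothing / vocab
on_value = 1.0 - smoothing + off_value
targets = np.array([2, 0])

soft_targets = np.full((len(targets), vocab), off_value)
soft_targets[np.arange(len(targets)), targets] = on_value
assert np.allclose(soft_targets.sum(axis=1), 1.0)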
Example #19
  def _mtf_model_fn(self, features, mesh):
    self._original_features = features
    features = copy.copy(features)
    hparams = self._hparams
    extra_losses = []
    targets = tf.to_int32(features["targets"])
    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    if len(targets.get_shape()) > 2:
      tf.logging.info("targets = %s" % targets)
      targets = tf.squeeze(targets, [2, 3])
    # pad targets to max_length
    def pad_to_max_length(x):
      extra_length = hparams.max_length - tf.shape(x)[1]
      x = tf.pad(x, [[0, 0], [0, extra_length]])
      x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
      return x
    targets = pad_to_max_length(targets)
    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
      if key in features:
        features[key] = pad_to_max_length(features[key])
    if hparams.decoder_type == "autoregressive":
      shifted_targets = mtf.shift(
          targets, offset=1, dim=self.length_dim, wrap=False)
    elif hparams.decoder_type == "denoising":
      shifted_targets = self._noisy_targets(targets, extra_losses)
    else:
      raise ValueError(
          "unknown hparams.decoder_type = %s" % hparams.decoder_type)

    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = self._import_to_batch_by_length(
          features["targets_segmentation"], "targets_segmentation",
          mesh, hparams)
      targets_position = self._import_to_batch_by_length(
          features["targets_position"], "targets_position",
          mesh, hparams)
      decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
          targets_segmentation, dtype=self.activation_dtype)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
    else:
      targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
      else:
        decoder_self_attention_mask = None

    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "decoder":
      encoder_output = None
      encoder_decoder_attention_mask = None
    else:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = pad_to_max_length(inputs)
      inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = self._import_to_batch_by_length(
            features["inputs_segmentation"], "inputs_segmentation",
            mesh, hparams)
        inputs_position = self._import_to_batch_by_length(
            features["inputs_position"], "inputs_position",
            mesh, hparams)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                inputs_segmentation, dtype=self.activation_dtype))
      else:
        inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))

      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.gather(positional_embedding_var, inputs_position,
                      self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_self_attention_mask,
                              losses=extra_losses)

    if hparams.transformer_type == "encdec":
      if "inputs_segmentation" in features:
        encoder_decoder_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        encoder_decoder_attention_mask = encoder_self_attention_mask
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)

    if hparams.transformer_type != "encoder":
      # DECODER
      x = (mtf.gather(
          targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
           mtf.gather(
               positional_embedding_var, targets_position, self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("decoder"):
        x = self._layer_stack(
            x,
            hparams.decoder_layers,
            encoder_output=encoder_output,
            self_attention_mask=decoder_self_attention_mask,
            encdec_attention_mask=encoder_decoder_attention_mask,
            losses=extra_losses)
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      # For some reason, the logits computation is extremely slow on TPU
      # in some cases where the batch size per core is 1.  Reshape the logits
      # and the targets to double the batch size and halve the length.
      # TODO(noam): file a bug.
      old_dims = self.batch_dims + [self.length_dim]
      new_dims = self.batch_dims[:-1] + [
          mtf.Dimension(self.batch_dims[-1].name,
                        self.batch_dims[-1].size * 2),
          mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
      x = mtf.reshape(x, new_dims + [self.model_dim])
      targets = mtf.reshape(targets, new_dims)

    logits = mtf.matmul(x, softmax_var)
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
      loss += l
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
    logits = mtf.to_float(logits)
    return logits, loss
Example #20
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
    """Bottleneck block variant for residual networks with BN after convolutions.

  Args:
    inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters for the first and third convolutions.
        Note that the middle convolution uses 4 times as many filters.
    is_training: `bool` for whether the model is in training mode.
    strides: `int` block stride. If greater than 1, this block will ultimately
        downsample the input.
    projection_shortcut: `function` to use for projection shortcuts (typically
        a 1x1 convolution to match the filter dimensions). If None, no
        projection is used and the input is passed as unchanged through the
        shortcut connection.
    row_blocks_dim: a mtf.Dimension, row dimension which is
        spatially partitioned along a mesh axis
    col_blocks_dim: a mtf.Dimension, column dimension which is
        spatially partitioned along a mesh axis

  Returns:
    The output `Tensor` of the block.
  """
    shortcut = inputs

    if projection_shortcut is not None:
        filters_dim = mtf.Dimension("filtersp", filters)
        shortcut = projection_shortcut(inputs, filters_dim)

    # First conv block
    inputs = mtf.layers.conv2d_with_blocks(inputs,
                                           mtf.Dimension("filters1", filters),
                                           filter_size=[1, 1],
                                           strides=[1, 1],
                                           padding="SAME",
                                           h_blocks_dim=None,
                                           w_blocks_dim=col_blocks_dim,
                                           name="conv0")

    # TODO(nikip): Add Dropout?
    inputs = batch_norm_relu(inputs, is_training)

    # Second conv block
    inputs = mtf.layers.conv2d_with_blocks(inputs,
                                           mtf.Dimension(
                                               "filters2", 4 * filters),
                                           filter_size=[3, 3],
                                           strides=[1, 1],
                                           padding="SAME",
                                           h_blocks_dim=row_blocks_dim,
                                           w_blocks_dim=col_blocks_dim,
                                           name="conv1")

    inputs = batch_norm_relu(inputs, is_training)

    # Third conv block (1x1, applies the block stride)
    inputs = mtf.layers.conv2d_with_blocks(inputs,
                                           mtf.Dimension("filters3", filters),
                                           filter_size=[1, 1],
                                           strides=strides,
                                           padding="SAME",
                                           h_blocks_dim=None,
                                           w_blocks_dim=col_blocks_dim,
                                           name="conv2")

    # TODO(nikip): Although the original ResNet code has this batch norm, in
    # our setup it causes no gradients to be passed. Investigate further.
    # inputs = batch_norm_relu(inputs, is_training, relu=True)

    # TODO(nikip): Maybe add residual with a projection?
    return mtf.relu(shortcut + mtf.rename_dimension(
        inputs, inputs.shape.dims[-1].name, shortcut.shape.dims[-1].name))
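
The rename in the final residual add is load-bearing: Mesh TensorFlow aligns binary-op operands by dimension name, so adding tensors whose channel dimensions are named differently would broadcast into an extra dimension rather than add elementwise. A minimal sketch of the pattern (dimension names and sizes here are illustrative):

import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")
rows = mtf.Dimension("rows", 4)
x = mtf.zeros(mesh, mtf.Shape([rows, mtf.Dimension("filters3", 8)]))
shortcut = mtf.zeros(mesh, mtf.Shape([rows, mtf.Dimension("filtersp", 8)]))
# x + shortcut here would broadcast to [rows, filters3, filtersp]; renaming
# the channel dimension makes the add elementwise instead.
x = mtf.rename_dimension(x, "filters3", "filtersp")
out = mtf.relu(shortcut + x)   # shape [rows, filtersp]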
Example No. 21
    def _sample(self, features, mesh):
        hparams = self._hparams
        (inputs_embedding_var, targets_embedding_var, softmax_var,
         positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
        if hparams.transformer_type == "encdec":
            inputs = features["inputs"]
            while len(inputs.shape.as_list()) > 2:
                inputs = tf.squeeze(inputs, axis=2)
            actual_batch_size = tf.shape(inputs)[0]
            actual_length = tf.shape(inputs)[1]
            inputs = tf.pad(inputs,
                            [[0, hparams.batch_size - actual_batch_size],
                             [0, hparams.max_length - actual_length]])
            inputs = self._import_to_batch_by_length(inputs, "inputs", mesh,
                                                     hparams)
            x = (mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim) +
                 mtf.reshape(positional_embedding_var,
                             mtf.Shape([self.length_dim, self.model_dim])))
            encoder_attention_mask = (mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
            with tf.variable_scope("encoder"):
                x = self._layer_stack(
                    x,
                    hparams.encoder_layers,
                    self_attention_mask=encoder_attention_mask)
            encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                                  self.memory_length_dim.name)
            encdec_tensors = []
            for layer_num, layer_type in enumerate(hparams.decoder_layers):
                if layer_type == "enc_att":
                    with tf.variable_scope("decoder/enc_att_%d/enc_att" %
                                           layer_num):
                        q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
                            mesh, self.heads_dim, self.model_dim, self.kv_dim,
                            self.master_dtype, self.slice_dtype,
                            self.activation_dtype)
                        k = mtf.einsum([encoder_output, k_var],
                                       mtf.Shape(self.batch_dims + [
                                           self.heads_dim,
                                           self.memory_length_dim, self.kv_dim
                                       ]))
                        v = mtf.einsum([encoder_output, v_var],
                                       mtf.Shape(self.batch_dims + [
                                           self.heads_dim,
                                           self.memory_length_dim, self.kv_dim
                                       ]))
                    encdec_tensors.append((q_var, o_var, k, v))
                else:
                    encdec_tensors.append(None)
            partial_targets = None
        elif hparams.transformer_type == "decoder":
            encdec_tensors = None
            encoder_output = None
            encoder_attention_mask = None
            # Prepare partial targets.
            # In either features["inputs"] or features["targets"].
            # We force the outputs to begin with these sequences.
            partial_targets = features.get("inputs", None)
            if partial_targets is None:
                partial_targets = features.get("targets", None)
            if partial_targets is not None:
                partial_targets = common_layers.expand_squeeze_to_nd(
                    partial_targets, 2)
                partial_targets = tf.to_int32(partial_targets)
                partial_targets_batch = tf.shape(partial_targets)[0]
                partial_targets_length = tf.shape(partial_targets)[1]
                partial_targets = tf.pad(
                    partial_targets,
                    [[0, hparams.batch_size - partial_targets_batch],
                     [0, hparams.max_length - partial_targets_length]])
                partial_targets = self._import_to_batch_by_length(
                    partial_targets, "partial_targets", mesh, hparams)
        else:
            raise ValueError("hparams.model_type = %s not yet supported" %
                             hparams.transformer_type)

        local_attention_window = mtf.Dimension(
            "local_attention_window", hparams.local_attention_window_size)
        if hparams.beam_size == 1:
            ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
            kv_shape = mtf.Shape(
                self.batch_dims +
                [self.heads_dim, self.memory_length_dim, self.kv_dim])
            local_kv_shape = mtf.Shape(
                self.batch_dims +
                [self.heads_dim, local_attention_window, self.kv_dim])
        else:
            beam_dim = mtf.Dimension("beam", hparams.beam_size)
            ids_shape = mtf.Shape(self.batch_dims +
                                  [beam_dim, self.length_dim])
            kv_shape = mtf.Shape(self.batch_dims + [
                beam_dim, self.heads_dim, self.memory_length_dim, self.kv_dim
            ])
            local_kv_shape = mtf.Shape(self.batch_dims + [
                beam_dim, self.heads_dim, local_attention_window, self.kv_dim
            ])

        initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
        initial_states = []
        for layer in hparams.decoder_layers:
            if layer == "att":
                initial_states.extend(
                    [mtf.zeros(mesh, kv_shape,
                               dtype=self.activation_dtype)] * 2)
            elif layer == "local_att":
                initial_states.extend(
                    [mtf.zeros(mesh, local_kv_shape,
                               dtype=self.activation_dtype)] * 2)

        def logits_fn(step_num, ids, states):
            """Produce logits for this step, and new states."""
            ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
            x = (mtf.gather(targets_embedding_var, ids_this_step,
                            self.targets_vocab_dim) +
                 mtf.gather(positional_embedding_var, step_num,
                            self.max_length_dim))
            with tf.variable_scope("decoder"):
                x, new_states = self._layer_stack(
                    x,
                    hparams.decoder_layers,
                    encdec_attention_mask=encoder_attention_mask,
                    step_num=step_num,
                    encdec_tensors=encdec_tensors,
                    states=states)
            logits = mtf.matmul(x, softmax_var)
            return logits, new_states

        if hparams.beam_size == 1:
            temperature = (0.0 if hparams.sampling_method == "argmax" else
                           hparams.sampling_temp)
            return mtf.beam_search.greedy_decode(logits_fn,
                                                 initial_ids,
                                                 temperature=temperature,
                                                 initial_states=initial_states,
                                                 forced_ids=partial_targets,
                                                 use_tpu=hparams.use_tpu)
        else:
            if hparams.transformer_type == "encdec":
                input_length = mtf.reduce_sum(
                    mtf.to_float(mtf.cast(inputs, tf.bool)),
                    reduced_dim=self.length_dim)
                max_input_length = mtf.reduce_max(input_length)
                decode_length = mtf.cast(
                    max_input_length * hparams.decode_length_multiplier +
                    hparams.decode_length_constant, tf.int32)
            else:
                decode_length = None
            beams, unused_scores = mtf.beam_search.beam_search(
                logits_fn,
                initial_ids,
                hparams.alpha,
                states=initial_states,
                decode_length=decode_length,
                use_tpu=hparams.use_tpu,
                dtype=self.activation_dtype)
            return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32),
                              beam_dim)
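
The `initial_states` list built above holds one cached-keys tensor and one cached-values tensor per attention layer, in layer order. A condensed sketch of that pattern (all dimension names, sizes, and the layer spec below are made-up placeholders):

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")
kv_shape = mtf.Shape([mtf.Dimension("batch", 2),
                      mtf.Dimension("heads", 4),
                      mtf.Dimension("memory_length", 16),
                      mtf.Dimension("d_kv", 8)])
initial_states = []
for layer in ["att", "drd", "att"]:   # hypothetical decoder layer spec
    if layer == "att":
        # two tensors per attention layer: cached keys and cached values
        initial_states.extend(
            [mtf.zeros(mesh, kv_shape, dtype=tf.float32)] * 2)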
Example No. 22
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
    """Bottleneck block variant for residual networks with BN after convolutions.

  Args:
    inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters for each of the three convolutions.
    is_training: `bool` for whether the model is in training mode.
    strides: `int` block stride. If greater than 1, this block will ultimately
        downsample the input.
    projection_shortcut: `function` to use for projection shortcuts (typically
        a 1x1 convolution to match the filter dimensions). If None, no
        projection is used and the input is passed as unchanged through the
        shortcut connection.
    row_blocks_dim: a mtf.Dimension, row dimension which is
        spatially partitioned along a mesh axis
    col_blocks_dim: a mtf.Dimension, column dimension which is
        spatially partitioned along a mesh axis

  Returns:
    The output `Tensor` of the block.
  """
    shortcut = inputs

    filter_h_dim = mtf.Dimension("filter_height", 3)
    filter_w_dim = mtf.Dimension("filter_width", 3)
    one_h_dim = mtf.Dimension("filter_height", 1)
    one_w_dim = mtf.Dimension("filter_width", 1)

    if projection_shortcut is not None:
        filters_dim = mtf.Dimension("filtersp", filters)
        kernel = mtf.get_variable(
            inputs.mesh, "kernel",
            mtf.Shape(
                [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters_dim]))
        shortcut = projection_shortcut(inputs, kernel)

    # First conv block
    filters1_dim = mtf.Dimension("filters1", filters)
    kernel1 = mtf.get_variable(
        inputs.mesh, "kernel1",
        mtf.Shape([one_h_dim, one_w_dim, inputs.shape.dims[-1], filters1_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    kernel1,
                                    strides=[1, 1, 1, 1],
                                    padding="SAME",
                                    h_blocks_dim=None,
                                    w_blocks_dim=col_blocks_dim)

    # TODO(nikip): Add Dropout?
    inputs = batch_norm_relu(inputs, is_training)

    # Second conv block
    filters2_dim = mtf.Dimension("filters2", filters)
    kernel2 = mtf.get_variable(
        inputs.mesh, "kernel2",
        mtf.Shape([filter_h_dim, filter_w_dim, filters1_dim, filters2_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    kernel2,
                                    strides=[1, 1, 1, 1],
                                    padding="SAME",
                                    h_blocks_dim=row_blocks_dim,
                                    w_blocks_dim=col_blocks_dim)

    inputs = batch_norm_relu(inputs, is_training)

    # Third conv block (1x1 "wide" kernel, applies the block stride)
    filters3_dim = mtf.Dimension("filters3", filters)
    filters3_kernel = mtf.get_variable(
        inputs.mesh, "wide_kernel",
        mtf.Shape([one_h_dim, one_w_dim, filters2_dim, filters3_dim]))
    inputs = mtf.conv2d_with_blocks(inputs,
                                    filters3_kernel,
                                    strides,
                                    padding="SAME",
                                    h_blocks_dim=None,
                                    w_blocks_dim=col_blocks_dim)

    inputs = batch_norm_relu(inputs, is_training, relu=False)

    # TODO(nikip): Maybe add residual with a projection?
    return mtf.relu(inputs + mtf.rename_dimension(
        shortcut, shortcut.shape.dims[-1].name, inputs.shape.dims[-1].name))
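
For reference, a `projection_shortcut` compatible with this variant takes the inputs and the 1x1 kernel created above and returns their convolution. A minimal sketch (this helper is an assumption for illustration, not part of the source):

def projection_shortcut(inputs, kernel):
    """Hypothetical 1x1 projection to match the shortcut's channel dim."""
    return mtf.conv2d_with_blocks(
        inputs, kernel,
        strides=[1, 1, 1, 1],   # use the block's strides when downsampling
        padding="SAME",
        h_blocks_dim=None,
        w_blocks_dim=None)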
Example No. 23
  def _call_internal(self, context, inputs, targets=None):
    """Compute logits based on inputs (all positions in parallel).

    Also updates context if applicable.

    Args:
      context: a Context
      inputs: a Tensor
      targets: an optional Tensor

    Returns:
      logits: a Tensor with shape [<batch_dims>, length_dim, output_vocab_dim]
    """
    mesh = inputs.mesh
    if "embedding" in context.shared_params:
      embedding_weights = context.shared_params["embedding"]
    else:
      embedding_weights = mtf.layers.embedding_weights(
          mesh, self.input_vocab_dim, self.model_dim, context.variable_dtype,
          name="embedding")
    x = mtf.gather(embedding_weights, inputs, self.input_vocab_dim)
    if "positional_embedding" in context.shared_params:
      pos_emb_var = context.shared_params["positional_embedding"]
    else:
      pos_emb_var = mtf.layers.embedding_weights(
          mesh, self.max_length_dim, self.model_dim, context.variable_dtype,
          "positional_embedding")
    if context.position_is_default:
      pos_emb = mtf.rename_dimension(
          mtf.slice(pos_emb_var, 0, context.length_dim.size,
                    self.max_length_dim.name),
          self.max_length_dim.name, context.length_dim.name)
    else:
      pos_emb = mtf.gather(
          pos_emb_var, context.position, self.max_length_dim,
          output_shape=x.shape)
    x += pos_emb
    x = self.layer_stack.call(context, x)
    if self.output_vocab_dim is None:
      return x
    if self.shared_embedding_and_softmax_weights:
      logits = mtf.einsum(
          [x * (self.model_dim.size ** -0.5), embedding_weights],
          reduced_dims=[self.model_dim])
    else:
      logits = mtf.layers.dense(
          x, self.output_vocab_dim, use_bias=False,
          variable_dtype=context.variable_dtype,
          name="logits")
    if targets is not None and context.losses is not None:
      off_value = self.label_smoothing / self.output_vocab_dim.size
      on_value = 1.0 - self.label_smoothing + off_value
      soft_targets = mtf.one_hot(
          targets, self.output_vocab_dim,
          dtype=context.activation_dtype,
          on_value=on_value,
          off_value=off_value)
      loss = mtf.layers.softmax_cross_entropy_with_logits(
          logits, soft_targets, self.output_vocab_dim,
          z_loss=self.z_loss if context.train else 0.0)
      weights = mtf.layers.weights_nonzero(
          targets, dtype=context.activation_dtype)
      loss = mtf.reduce_mean(loss * weights)
      context.losses.append(loss)
    return logits
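
When `shared_embedding_and_softmax_weights` is set, the same matrix serves as both the input embedding and the output projection, with the `model_dim.size ** -0.5` rescale keeping logit magnitudes reasonable. A standalone sketch of the tying (dimension sizes are illustrative):

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")
vocab = mtf.Dimension("vocab", 100)
model = mtf.Dimension("d_model", 16)
length = mtf.Dimension("length", 8)
embedding = mtf.get_variable(mesh, "embedding", mtf.Shape([vocab, model]))
ids = mtf.zeros(mesh, mtf.Shape([length]), dtype=tf.int32)
x = mtf.gather(embedding, ids, vocab)                # [length, d_model]
logits = mtf.einsum([x * (model.size ** -0.5), embedding],
                    reduced_dims=[model])            # [length, vocab]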
Example No. 24
  def _mtf_model_fn(self, features, mesh):
    self._original_features = features
    features = copy.copy(features)
    hparams = self._hparams
    extra_losses = []
    targets = tf.to_int32(features["targets"])
    if len(targets.get_shape()) > 2:
      tf.logging.info("targets = %s" % targets)
      targets = tf.squeeze(targets, [2, 3])
    # pad targets to max_length
    def pad_to_max_length(x):
      extra_length = hparams.max_length - tf.shape(x)[1]
      x = tf.pad(x, [[0, 0], [0, extra_length]])
      x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
      return x
    targets = pad_to_max_length(targets)
    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
      if key in features:
        features[key] = pad_to_max_length(features[key])
    if hparams.decoder_type == "autoregressive":
      shifted_targets = mtf.shift(
          targets, offset=1, dim=self.length_dim, wrap=False)
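      # With offset=1 and wrap=False, position t of shifted_targets holds
      # targets[t - 1] and position 0 holds zero, so each position is
      # predicted only from earlier tokens.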
    elif hparams.decoder_type == "denoising":
      shifted_targets = self._noisy_targets(targets, extra_losses)
    else:
      raise ValueError(
          "unknown hparams.decoder_type = %s" % hparams.decoder_type)

    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = self._import_to_batch_by_length(
          features["targets_segmentation"], "targets_segmentation",
          mesh, hparams)
      targets_position = self._import_to_batch_by_length(
          features["targets_position"], "targets_position",
          mesh, hparams)
      decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
          targets_segmentation, dtype=self.activation_dtype)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
    else:
      targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      if hparams.decoder_type == "autoregressive":
        decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype)
      else:
        decoder_self_attention_mask = None

    def layer_prepostprocess_dropout(x):
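      # noise_shape omits length_dim, so one dropout mask is shared across
      # all positions of a sequence.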
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "decoder":
      encoder_output = None
      encoder_decoder_attention_mask = None
    else:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = pad_to_max_length(inputs)
      inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = self._import_to_batch_by_length(
            features["inputs_segmentation"], "inputs_segmentation",
            mesh, hparams)
        inputs_position = self._import_to_batch_by_length(
            features["inputs_position"], "inputs_position",
            mesh, hparams)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                inputs_segmentation, dtype=self.activation_dtype))
      else:
        inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        encoder_self_attention_mask = (
            mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))

      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.gather(positional_embedding_var, inputs_position,
                      self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_self_attention_mask,
                              losses=extra_losses)

    if hparams.transformer_type == "encdec":
      if "inputs_segmentation" in features:
        encoder_decoder_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        encoder_decoder_attention_mask = encoder_self_attention_mask
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)

    if hparams.transformer_type != "encoder":
      # DECODER
      x = (mtf.gather(
          targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
           mtf.gather(
               positional_embedding_var, targets_position, self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("decoder"):
        x = self._layer_stack(
            x,
            hparams.decoder_layers,
            encoder_output=encoder_output,
            self_attention_mask=decoder_self_attention_mask,
            encdec_attention_mask=encoder_decoder_attention_mask,
            losses=extra_losses)
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      # For some reason, the logits computation is extremely slow on TPU
      # in some cases where the batch size per core is 1.  Reshape the logits
      # and the targets to double the batch size and halve the length.
      # TODO(noam): file a bug.
      old_dims = self.batch_dims + [self.length_dim]
      new_dims = self.batch_dims[:-1] + [
          mtf.Dimension(self.batch_dims[-1].name,
                        self.batch_dims[-1].size * 2),
          mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
      x = mtf.reshape(x, new_dims + [self.model_dim])
      targets = mtf.reshape(targets, new_dims)

    logits = mtf.matmul(x, softmax_var)
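    # During training, multiply the logits by random factors close to 1.0
    # as a mild regularizer.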
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
      logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
      loss += l
    if (hparams.reshape_logits_hack and
        hparams.mode == tf.estimator.ModeKeys.TRAIN):
      logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
    logits = mtf.to_float(logits)
    return logits, loss
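
A closing note on the reshape hack above: doubling the last batch dimension while halving the length preserves the total number of positions, so logits and loss are computed over exactly the same elements. A quick illustrative check with made-up sizes:

batch_size, max_length = 8, 256   # hypothetical per-run sizes
assert (batch_size * 2) * (max_length // 2) == batch_size * max_length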