Example #1
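# NOTE: these methods are excerpted from a Mesh TensorFlow class body; the
# enclosing module is assumed to import `math`, `tensorflow.compat.v1 as tf`,
# and `mesh_tensorflow as mtf`.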
  def get_log_softmax_prefix(self, log_softmax, end_index):
    """Returns first end_index entries in log_softmax along the vocab dim."""
    prefix_dim = mtf.Dimension(self._vocab_dim.name, end_index)

    indices = mtf.mtf_range(
        log_softmax.mesh, dim=self._vocab_dim, dtype=tf.int32)
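    # Map positions at or beyond end_index to -1; mtf.one_hot produces an
    # all-zero row for -1, so those vocab entries drop out of the projection.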
    prefix_indices = mtf.where(mtf.less(indices, end_index), indices, -1)
    projection = mtf.one_hot(
        prefix_indices, prefix_dim, dtype=log_softmax.dtype)

    return mtf.einsum([log_softmax, projection], reduced_dims=[self._vocab_dim])
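
  # A minimal NumPy sketch (hypothetical sizes, outside Mesh TensorFlow) of
  # the projection built above: positions >= end_index map to -1, whose
  # one-hot row is all zeros, so the matmul keeps only the first end_index
  # vocab entries.
  #
  #   import numpy as np
  #   vocab_size, end_index = 8, 3
  #   log_softmax = np.log(np.full(vocab_size, 1.0 / vocab_size))
  #   indices = np.arange(vocab_size)
  #   prefix_indices = np.where(indices < end_index, indices, -1)
  #   projection = (prefix_indices[:, None]
  #                 == np.arange(end_index)[None, :]).astype(float)
  #   prefix = log_softmax @ projection  # shape [end_index]
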
  def get_timing_signal_1d(self,
                           context,
                           length,
                           channels,
                           min_timescale=1.0,
                           max_timescale=1.0e4,
                           start_index=0):
        """Gets a bunch of sinusoids of different frequencies.

    Each channel of the input Tensor is incremented by a sinusoid of a different
    frequency and phase.

    This allows attention to learn to use absolute and relative positions.
    Timing signals should be added to some precursors of both the query and the
    memory inputs to attention.

    The use of relative position is possible because sin(x+y) and cos(x+y) can
    be expressed in terms of y, sin(x) and cos(x).

    In particular, we use a geometric sequence of timescales starting with
    min_timescale and ending with max_timescale.  The number of different
    timescales is equal to channels / 2. For each timescale, we
    generate the two sinusoidal signals sin(timestep/timescale) and
    cos(timestep/timescale).  All of these sinusoids are concatenated in
    the channels dimension.

    Args:
      context: mtf context.
      length: a mtf.Dimension, length of timing signal sequence.
      channels: a mtf.Dimension, size of timing embeddings to create.
      The number of different timescales is equal to channels / 2.
      min_timescale: a float
      max_timescale: a float
      start_index: index of first position

    Returns:
      a Tensor of timing signals [1, length, channels]
    """

    if channels.size % 2 != 0:
      raise NotImplementedError("Odd channel size not implemented.")

    # Positions are assumed int-valued; cast them to the activation dtype so
    # they can be scaled by the float inverse timescales below.
    position = mtf.cast(
        context.get_position() + start_index, context.activation_dtype)
    num_timescales = channels.size // 2
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        max(num_timescales - 1, 1))
    inv_timescales = min_timescale * mtf.exp(
        mtf.mtf_range(context.mesh,
                      mtf.Dimension(channels.name, num_timescales),
                      context.activation_dtype) * -log_timescale_increment)

    scaled_time = position * inv_timescales
    # Please note that this slightly differs from the published paper.
    # See a discussion here:
    # https://github.com/tensorflow/tensor2tensor/pull/177
    signal = mtf.concat(
        [mtf.sin(scaled_time), mtf.cos(scaled_time)],
        concat_dim_name=channels.name)

    # length and channels are mtf.Dimensions, so they are used directly as
    # the output dims: [1, length, channels].
    new_dims = [mtf.Dimension("expanded", 1), length, channels]
    return mtf.reshape(signal, mtf.Shape(new_dims))
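
  # A minimal NumPy sketch of the same sinusoid construction (hypothetical
  # sizes, outside Mesh TensorFlow): a geometric sequence of channels / 2
  # inverse timescales, with sin and cos signals concatenated along the
  # channel dimension.
  #
  #   import math
  #   import numpy as np
  #   length, channels = 4, 6
  #   min_timescale, max_timescale = 1.0, 1.0e4
  #   num_timescales = channels // 2
  #   log_increment = (math.log(max_timescale / min_timescale) /
  #                    max(num_timescales - 1, 1))
  #   inv_timescales = min_timescale * np.exp(
  #       np.arange(num_timescales) * -log_increment)
  #   position = np.arange(length, dtype=float)
  #   scaled_time = position[:, None] * inv_timescales[None, :]
  #   signal = np.concatenate(
  #       [np.sin(scaled_time), np.cos(scaled_time)], axis=1)
  #   signal = signal[None, :, :]  # shape [1, length, channels]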