Esempio n. 1
0
def clip_by_global_norm(grads, clip_norm):
    """Scale gradients so that their joint L2 norm does not exceed clip_norm.

    Args:
      grads: gradients to clip. The collection is traversed twice; entries
        that are None are skipped for the norm and returned as None.
      clip_norm: the norm threshold.

    Returns:
      a pair (clipped_grads, global_norm): a list parallel to grads and the
      global norm computed before clipping.
    """
    squared_sums = []
    for grad in grads:
        if grad is not None:
            squared_sums.append(mtf.reduce_sum(mtf.square(grad)))
    global_norm = mtf.sqrt(mtf.add_n(squared_sums))

    # The factor is 1 while the norm stays within clip_norm and
    # clip_norm / global_norm once it exceeds the threshold.
    scale = clip_norm / mtf.maximum(global_norm, clip_norm)

    clipped_grads = []
    for grad in grads:
        clipped_grads.append(None if grad is None else grad * scale)
    return clipped_grads, global_norm
Esempio n. 2
0
def get_activation_fn(params):
    """Return the activation function named by ``params["activation_fn"]``.

    Args:
        params: a mapping exposing ``get``; the "activation_fn" entry selects
            the activation and defaults to "gelu" when absent.

    Returns:
        A callable taking a single tensor argument.

    Raises:
        ValueError: if the configured name is not a known activation. The
            message contains the rejected value.
    """
    activation_fn = params.get("activation_fn", "gelu")

    # Activations that mtf provides under the very same name. They are
    # resolved lazily with getattr so only the selected one has to exist.
    #   gelu: https://arxiv.org/abs/1606.08415
    #   selu: https://arxiv.org/abs/1706.02515
    #   elu: https://arxiv.org/abs/1511.07289
    #   swish: https://arxiv.org/abs/1710.05941
    mtf_builtin = (
        "gelu",
        "relu",
        "sigmoid",
        "tanh",
        "selu",
        "elu",
        "swish",
        "softplus",
    )

    composite = {
        # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
        "maxsig": lambda x: mtf.maximum(x, mtf.sigmoid(x)),
        "cosid": lambda x: mtf.cos(x) - x,
        "minsin": lambda x: mtf.minimum(x, mtf.sin(x)),
        "maxtanh": lambda x: mtf.maximum(x, mtf.tanh(x)),
        # https://arxiv.org/abs/1908.08681
        "mish": lambda x: x * mtf.tanh(mtf.softplus(x)),
        # https://arxiv.org/abs/2003.09855
        "tanhexp": lambda x: x * mtf.tanh(mtf.exp(x)),
        # https://arxiv.org/abs/1901.05894
        "lisht": lambda x: x * mtf.tanh(x),
        # https://arxiv.org/abs/2011.11713
        "seagull": lambda x: mtf.log(1 + x ** 2),
        # https://arxiv.org/abs/2006.08195
        "snake": lambda x: x + mtf.sin(x) ** 2,
    }

    # Only strings can name an activation; the guard also keeps unhashable
    # config values away from the dict lookup so they still end in ValueError.
    if isinstance(activation_fn, str):
        if activation_fn in mtf_builtin:
            return getattr(mtf, activation_fn)
        if activation_fn in composite:
            return composite[activation_fn]

    raise ValueError(
        f'unknown activation function "{activation_fn}" in config')
Esempio n. 3
0
def _relative_position_bucket(relative_position,
                              bidirectional=True,
                              num_buckets=32,
                              max_distance=128):
  """Map relative positions to bucket ids for relative attention.

  relative_position is memory_position - query_position.  The function works
  on its negation: with bidirectional=True the sign selects one of two halves
  of the bucket range and the magnitude picks the bucket inside that half;
  with bidirectional=False negated values below zero (i.e. positive relative
  positions) are clamped to zero.

  Magnitudes below half of the available buckets get one bucket each.  Larger
  magnitudes are bucketed logarithmically, and the result is capped at the
  last bucket, so every distance of max_distance or more shares a bucket.

  Args:
    relative_position: an int32 Tensor
    bidirectional: a boolean - whether the attention is bidirectional
    num_buckets: an integer
    max_distance: an integer
  Returns:
    a Tensor with the same shape as relative_position, containing int32
      values in the range [0, num_buckets)
  """
  bucket = 0
  distance = -relative_position
  if not bidirectional:
    distance = mtf.maximum(distance, 0)
  else:
    # Half of the buckets is reserved for each sign of the distance.
    num_buckets //= 2
    bucket += mtf.to_int32(mtf.less(distance, 0)) * num_buckets
    distance = mtf.abs(distance)
  # From here on distance is non-negative.
  exact_limit = num_buckets // 2
  within_exact = mtf.less(distance, exact_limit)
  # Logarithmic bucketing for distances in [exact_limit, max_distance).
  log_ratio = mtf.log(mtf.to_float(distance) / exact_limit)
  log_span = math.log(max_distance / exact_limit)
  log_bucket = mtf.to_int32(
      log_ratio / log_span * (num_buckets - exact_limit))
  coarse_bucket = mtf.minimum(exact_limit + log_bucket, num_buckets - 1)
  bucket += mtf.where(within_exact, distance, coarse_bucket)
  return bucket
def hybrid_attention(q,
                     k,
                     v,
                     context,
                     memory_length_dim,
                     key_dim,
                     value_dim,
                     bias=None,
                     dropout_rate=0.0,
                     dropout_broadcast_dims=None,
                     extra_logit=None):
  """Dot-product attention mixing two normalizations of the logits.

  Two sets of attention weights are computed from the same logits:
    * a softmax over memory_length_dim, and
    * a log-softmax over a dimension named "length" (built with the size of
      memory_length_dim) followed by a softmax over memory_length_dim.
  They are blended as c * second + (1 - c) * first, where c is the scalar
  variable "doubly_coeff", initialized to 0.5 and clamped to [0, 1].

  key_dim is a Dimension representing the channels in the queries and keys
  value_dim is a Dimension representing the channels in values
  memory_length_dim is a Dimension representing the different key/value pairs.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
  other_memory_dims is a subset of other_query_dims

  NOTE(review): the query-length dimension is rebuilt by name and size, so
  this presumably expects the query length dimension to be called "length"
  and to have the same size as memory_length_dim - confirm with callers.

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    context: context of the attention layer; its mesh, variable_dtype and
      train attributes are read.
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: a Tensor to be added into the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """
  scores = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
  if bias is not None:
    scores += bias

  query_length_dim = mtf.Dimension("length", memory_length_dim.size)
  mix = mtf.get_variable(
      context.mesh, "doubly_coeff", [],
      initializer=tf.constant_initializer(0.5),
      dtype=context.variable_dtype)
  # Keep the blend coefficient inside [0, 1].
  mix = mtf.maximum(mtf.minimum(mix, 1.), 0.)

  # Weights normalized over the memory positions only.
  memory_normalized = mtf.softmax(
      scores, memory_length_dim, extra_logit=extra_logit)

  # Weights normalized over the query positions first, then over memory.
  query_log_normalized = mtf.log_softmax(
      scores, query_length_dim, extra_logit=extra_logit)
  doubly_normalized = mtf.softmax(
      query_log_normalized, memory_length_dim, extra_logit=extra_logit)

  doubly_part = mix * doubly_normalized
  memory_part = (1. - mix) * memory_normalized
  weights = doubly_part + memory_part

  keep_fraction = 1.0 - dropout_rate
  weights = mtf.dropout(
      weights, context.train, keep_fraction,
      noise_shape=weights.shape - dropout_broadcast_dims)
  result_shape = q.shape - key_dim + value_dim
  return mtf.einsum([weights, v], result_shape)
    def get_timing_signal_1d(self,
                             context,
                             length,
                             channels,
                             min_timescale=1.0,
                             max_timescale=1.0e4,
                             start_index=0):
        """Gets a bunch of sinusoids of different frequencies.

        Each channel of the input Tensor is incremented by a sinusoid of a
        different frequency and phase.

        This allows attention to learn to use absolute and relative
        positions. Timing signals should be added to some precursors of both
        the query and the memory inputs to attention.

        The use of relative position is possible because sin(x+y) and
        cos(x+y) can be expressed in terms of y, sin(x) and cos(x).

        In particular, we use a geometric sequence of timescales starting
        with min_timescale and ending with max_timescale.  The number of
        different timescales is equal to channels / 2. For each timescale, we
        generate the two sinusoidal signals sin(timestep/timescale) and
        cos(timestep/timescale).  All of these sinusoids are concatenated in
        the channels dimension.

        Args:
          context: mtf context; its mesh, activation_dtype and get_position()
            are used.
          length: a mtf.Dimension, length of timing signal sequence.
          channels: a mtf.Dimension, size of timing embeddings to create.
            The number of different timescales is equal to channels / 2.
          min_timescale: a float
          max_timescale: a float
          start_index: index of first position

        Returns:
          a Tensor of timing signals [1, length, channels]

        Raises:
          NotImplementedError: if channels.size is odd.
        """
        # Reject unsupported sizes before any op is added to the graph.
        if channels.size % 2 != 0:
            raise NotImplementedError("Odd channel size not implemented.")

        position = context.get_position() + start_index
        num_timescales = mtf.constant(context.mesh, channels.size // 2)
        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            mtf.maximum(num_timescales - 1, 1))
        channel_dim_name = channels.name
        inv_timescales = (min_timescale * mtf.exp(
            mtf.mtf_range(context.mesh,
                          mtf.Dimension(channel_dim_name, channels.size // 2),
                          context.activation_dtype) * -log_timescale_increment)
                          )

        scaled_time = position * inv_timescales
        # Please note that this slightly differs from the published paper.
        # See a discussion here:
        # https://github.com/tensorflow/tensor2tensor/pull/177
        concat_dim_name = channels.name
        signal = mtf.concat(
            [mtf.sin(scaled_time), mtf.cos(scaled_time)],
            concat_dim_name=concat_dim_name)

        # length and channels are Dimensions themselves, so they go into the
        # target shape directly (a Dimension has no `.shape` attribute).
        # NOTE(review): the reshape presumes the position tensor carries only
        # the length dimension, so that element counts match - TODO confirm.
        new_dims = [mtf.Dimension("expanded", 1), length, channels]
        signal = mtf.reshape(signal, mtf.Shape(new_dims))
        return signal
Esempio n. 6
0
def get_activation_fn(params):
    """Return the activation function named by ``params["activation_fn"]``.

    Args:
        params: a mapping exposing ``get``; the "activation_fn" entry selects
            the activation and defaults to "gelu" when absent.

    Returns:
        A callable taking a single tensor ``x``. The parametric activations
        ("aria", "prelu", "parcsinh", "psoftplus", "proottanh") create new
        scalar variables on ``x.mesh`` each time they are applied.

    Raises:
        ValueError: if the configured name is not a known activation. The
            message contains the rejected value.
    """
    activation_fn = params.get("activation_fn", "gelu")

    def _arcsinh(x):
        return mtf.log(x + mtf.sqrt(1 + x**2))

    def _var(x, init):
        # Scalar variable initialized to `init`; the random hex suffix gives
        # every call its own variable name.
        # NOTE(review): names therefore differ from one graph build to the
        # next - confirm this is acceptable for checkpoint restore.
        return mtf.get_variable(x.mesh,
                                f"activation-{random.randint(0, 2 ** 32):x}",
                                [],
                                initializer=tf.constant_initializer(init),
                                dtype=x.dtype)

    def _pos_var(x, val):
        # softplus keeps the learned part positive, so the result is > val.
        return mtf.softplus(_var(x, 0)) + val

    def _rrelu_fn(x):
        # The negative slope is drawn in Python when the activation is
        # applied, so it is a constant inside the built graph.
        negative_scale = random.random()
        return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)

    def _elish_fn(x):
        # ELiSH: x * sigmoid(x) for x > 0 and (exp(x) - 1) * sigmoid(x)
        # otherwise; sigmoid(x) is written as 1 / (1 / exp(x) + 1).
        cond = mtf.cast(mtf.greater(x, 0), x.dtype)
        exp = mtf.exp(x)
        return (cond * x + (1 - cond) * (exp - 1)) / (1 / exp + 1)

    # Config name -> attribute of mtf. Resolved lazily with getattr so only
    # the selected activation has to exist.
    #   gelu: https://arxiv.org/abs/1606.08415
    #   selu: https://arxiv.org/abs/1706.02515
    #   elu: https://arxiv.org/abs/1511.07289
    #   silu: https://arxiv.org/abs/1710.05941
    mtf_builtin = {
        "gelu": "gelu",
        "relu": "relu",
        "sigmoid": "sigmoid",
        "tanh": "tanh",
        "selu": "selu",
        "elu": "elu",
        "abs": "abs",
        "sin": "sin",
        "cos": "cos",
        "sign": "sign",
        "silu": "swish",
        "softplus": "softplus",
    }

    composite = {
        "lrelu001": lambda x: mtf.leaky_relu(x, alpha=0.01),
        "lrelu020": lambda x: mtf.leaky_relu(x, alpha=0.20),
        "id": lambda x: x,
        "triangle_relax": lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49,
        "square_relax": lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7,
        "spike": lambda x: 1 / (1 + x**2),
        "spike2": lambda x: mtf.exp(-x**2),
        "tanhshrink": lambda x: x - mtf.tanh(x),
        "softsign": lambda x: x / (mtf.abs(x) + 1),
        "softmax": lambda x: mtf.softmax(x, x.shape[-1]),
        "logsoftmax": lambda x: mtf.log_softmax(x, x.shape[-1]),
        "bipolarsigmoid": lambda x: mtf.sigmoid(x) * 2 - 1,
        # https://arxiv.org/abs/1505.00853
        "rrelu": _rrelu_fn,
        # https://arxiv.org/abs/1808.00783v1
        "elish": _elish_fn,
        "arcsinh": _arcsinh,
        # parametric
        # https://arxiv.org/abs/1805.08878
        "aria": lambda x: x * (_var(x, 0) + _var(x, 1) / (_pos_var(x, 0) + _var(
            x, 1) * mtf.exp(_var(x, -1) * x)**(1 / _pos_var(x, 1)))),
        # https://arxiv.org/abs/1502.01852
        "prelu": lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)),
        "parcsinh": lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)),
        "psoftplus": lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(
            x, 0),
        "proottanh": lambda x: (x**_pos_var(x, 2) + _pos_var(x, 1))**(
            1 / _pos_var(x, 3)) * mtf.tanh(x),
        # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
        "maxsig": lambda x: mtf.maximum(x, mtf.sigmoid(x)),
        "cosid": lambda x: mtf.cos(x) - x,
        "minsin": lambda x: mtf.minimum(x, mtf.sin(x)),
        "maxtanh": lambda x: mtf.maximum(x, mtf.tanh(x)),
        # https://arxiv.org/abs/1908.08681
        "mish": lambda x: x * mtf.tanh(mtf.softplus(x)),
        # https://arxiv.org/abs/2003.09855
        "tanhexp": lambda x: x * mtf.tanh(mtf.exp(x)),
        # https://arxiv.org/abs/1901.05894
        "lisht": lambda x: x * mtf.tanh(x),
        # https://arxiv.org/abs/2011.11713
        "seagull": lambda x: mtf.log(1 + x**2),
        # https://arxiv.org/abs/2006.08195
        "snake": lambda x: x + mtf.sin(x)**2,
        # made up
        "roottanh": lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x),
        "softplusmone": lambda x: mtf.softplus(x) - 1,
    }

    # Only strings can name an activation; the guard also keeps unhashable
    # config values away from the dict lookups so they still end in
    # ValueError.
    if isinstance(activation_fn, str):
        if activation_fn in mtf_builtin:
            return getattr(mtf, mtf_builtin[activation_fn])
        if activation_fn in composite:
            return composite[activation_fn]

    raise ValueError(
        f'unknown activation function "{activation_fn}" in config')
Esempio n. 7
0
def get_activation_fn(params):
    """Return the activation function named by ``params["activation_fn"]``.

    Args:
        params: a mapping exposing ``get``; the "activation_fn" entry selects
            the activation and defaults to "gelu" when absent.

    Returns:
        A callable taking a single tensor argument.

    Raises:
        ValueError: if the configured name is not a known activation. The
            message contains the rejected value.
    """
    activation_fn = params.get("activation_fn", "gelu")

    def _rrelu_fn(x):
        # The negative slope is drawn in Python when the activation is
        # applied, so it is a constant inside the built graph.
        negative_scale = random.random()
        return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)

    def _elish_fn(x):
        # ELiSH: x * sigmoid(x) for x > 0 and (exp(x) - 1) * sigmoid(x)
        # otherwise; sigmoid(x) is written as 1 / (1 / exp(x) + 1).
        cond = mtf.cast(mtf.greater(x, 0), x.dtype)
        exp = mtf.exp(x)
        return (cond * x + (1 - cond) * (exp - 1)) / (1 / exp + 1)

    # Activations that mtf provides under the very same name. They are
    # resolved lazily with getattr so only the selected one has to exist.
    #   gelu: https://arxiv.org/abs/1606.08415
    #   selu: https://arxiv.org/abs/1706.02515
    #   elu: https://arxiv.org/abs/1511.07289
    #   swish: https://arxiv.org/abs/1710.05941
    mtf_builtin = (
        "gelu",
        "relu",
        "sigmoid",
        "tanh",
        "selu",
        "elu",
        "abs",
        "sin",
        "cos",
        "sign",
        "swish",
        "softplus",
    )

    composite = {
        "id": lambda x: x,
        "triangle_relax": lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49,
        "square_relax": lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7,
        "spike": lambda x: 1 / (1 + x**2),
        "spike2": lambda x: mtf.exp(-x**2),
        "tanhshrink": lambda x: x - mtf.tanh(x),
        "softsign": lambda x: x / (mtf.abs(x) + 1),
        "softmax": lambda x: mtf.softmax(x, x.shape[-1]),
        "logsoftmax": lambda x: mtf.log_softmax(x, x.shape[-1]),
        "bipolarsigmoid": lambda x: mtf.sigmoid(x) * 2 - 1,
        # https://arxiv.org/abs/1505.00853
        "rrelu": _rrelu_fn,
        # https://arxiv.org/abs/1808.00783v1
        "elish": _elish_fn,
        # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
        "maxsig": lambda x: mtf.maximum(x, mtf.sigmoid(x)),
        "cosid": lambda x: mtf.cos(x) - x,
        "minsin": lambda x: mtf.minimum(x, mtf.sin(x)),
        "maxtanh": lambda x: mtf.maximum(x, mtf.tanh(x)),
        # https://arxiv.org/abs/1908.08681
        "mish": lambda x: x * mtf.tanh(mtf.softplus(x)),
        # https://arxiv.org/abs/2003.09855
        "tanhexp": lambda x: x * mtf.tanh(mtf.exp(x)),
        # https://arxiv.org/abs/1901.05894
        "lisht": lambda x: x * mtf.tanh(x),
        # https://arxiv.org/abs/2011.11713
        "seagull": lambda x: mtf.log(1 + x**2),
        # https://arxiv.org/abs/2006.08195
        "snake": lambda x: x + mtf.sin(x)**2,
        # made up
        "roottanh": lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x),
        "softplusmone": lambda x: mtf.softplus(x) - 1,
    }

    # Only strings can name an activation; the guard also keeps unhashable
    # config values away from the dict lookup so they still end in ValueError.
    if isinstance(activation_fn, str):
        if activation_fn in mtf_builtin:
            return getattr(mtf, activation_fn)
        if activation_fn in composite:
            return composite[activation_fn]

    raise ValueError(
        f'unknown activation function "{activation_fn}" in config')
Esempio n. 8
0
 _arcsinh,
 'aria':
 lambda x: x * (_var(x, 0) + _var(x, 1) /
                (_pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x)**
                 (1 / _pos_var(x, 1)))),
 'prelu':
 lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)),
 'parcsinh':
 lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)),
 'psoftplus':
 lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0),
 'proottanh':
 lambda x:
 (x**_pos_var(x, 2) + _pos_var(x, 1))**(1 / _pos_var(x, 3)) * mtf.tanh(x),
 'maxsig':
 lambda x: mtf.maximum(x, mtf.sigmoid(x)),
 'cosid':
 lambda x: mtf.cos(x) - x,
 'minsin':
 lambda x: mtf.minimum(x, mtf.sin(x)),
 'maxtanh':
 lambda x: mtf.maximum(x, mtf.tanh(x)),
 'mish':
 lambda x: x * mtf.tanh(mtf.softplus(x)),
 'tanhexp':
 lambda x: x * mtf.tanh(mtf.exp(x)),
 'lisht':
 lambda x: x * mtf.tanh(x),
 'seagull':
 lambda x: mtf.log(1 + x**2),
 'snake':