Example no. 1
def _relative_position_bucket(relative_position,
                              bidirectional=True,
                              num_buckets=32,
                              max_distance=128):
  """Translate relative position to a bucket number for relative attention.

  The relative position is defined as memory_position - query_position, i.e.
  the distance in tokens from the attending position to the attended-to
  position.  If bidirectional=False, then positive relative positions are
  invalid.

  We use smaller buckets for small absolute relative_position and larger buckets
  for larger absolute relative_positions.  All relative positions >=max_distance
  map to the same bucket.  All relative positions <=-max_distance map to the
  same bucket.  This should allow for more graceful generalization to longer
  sequences than the model has been trained on.

  Args:
    relative_position: an int32 Tensor
    bidirectional: a boolean - whether the attention is bidirectional
    num_buckets: an integer
    max_distance: an integer
  Returns:
    a Tensor with the same shape as relative_position, containing int32
      values in the range [0, num_buckets)
  """
  ret = 0
  n = -relative_position
  if bidirectional:
    num_buckets //= 2
    ret += mtf.to_int32(mtf.less(n, 0)) * num_buckets
    n = mtf.abs(n)
  else:
    n = mtf.maximum(n, 0)
  # now n is in the range [0, inf)
  max_exact = num_buckets // 2
  is_small = mtf.less(n, max_exact)
  val_if_large = max_exact + mtf.to_int32(
      mtf.log(mtf.to_float(n) / max_exact)
      / math.log(max_distance / max_exact) * (num_buckets - max_exact))
  val_if_large = mtf.minimum(val_if_large, num_buckets - 1)
  ret += mtf.where(is_small, n, val_if_large)
  return ret
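For intuition, the same bucketing arithmetic can be run on plain Python integers; the sketch below mirrors the mtf ops above and is illustrative only, not part of mesh_tensorflow:

import math

def relative_position_bucket(relative_position, bidirectional=True,
                             num_buckets=32, max_distance=128):
  # Scalar mirror of the function above: exact buckets for small distances,
  # log-spaced buckets for large ones, saturating at num_buckets - 1.
  ret = 0
  n = -relative_position
  if bidirectional:
    num_buckets //= 2
    ret += num_buckets if n < 0 else 0
    n = abs(n)
  else:
    n = max(n, 0)
  max_exact = num_buckets // 2
  if n < max_exact:
    return ret + n
  val_if_large = max_exact + int(
      math.log(n / max_exact) / math.log(max_distance / max_exact)
      * (num_buckets - max_exact))
  return ret + min(val_if_large, num_buckets - 1)

# Nearby positions get distinct buckets, distant ones share log-spaced buckets:
print([relative_position_bucket(d) for d in (1, 2, 8, 50, 500)])  # [17, 18, 24, 29, 31]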
Example no. 2
def get_activation_fn(params):
    activation_fn = params.get("activation_fn", "gelu")
    if activation_fn == "gelu": # https://arxiv.org/abs/1606.08415
        return mtf.gelu
    elif activation_fn == "relu":
        return mtf.relu
    elif activation_fn == "sigmoid":
        return mtf.sigmoid
    elif activation_fn == "tanh":
        return mtf.tanh
    elif activation_fn == "selu": # https://arxiv.org/abs/1706.02515
        return mtf.selu
    elif activation_fn == "elu": # https://arxiv.org/abs/1511.07289
        return mtf.elu
    
    # swish activations
    elif activation_fn == "swish": # https://arxiv.org/abs/1710.05941
        return mtf.swish
    
    # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
    elif activation_fn == "maxsig": 
        return lambda x: mtf.maximum(x, mtf.sigmoid(x))
    elif activation_fn == "cosid": 
        return lambda x: mtf.cos(x) - x
    elif activation_fn == "minsin": 
        return lambda x: mtf.minimum(x, mtf.sin(x))
    elif activation_fn == "maxtanh": 
        return lambda x: mtf.maximum(x, mtf.tanh(x))
    
    elif activation_fn == "softplus":
        return mtf.softplus
    elif activation_fn == "mish": # https://arxiv.org/abs/1908.08681
        return lambda x: x * mtf.tanh(mtf.softplus(x))
    elif activation_fn == "tanhexp": # https://arxiv.org/abs/2003.09855
        return lambda x: x * mtf.tanh(mtf.exp(x))
    elif activation_fn == "lisht": # https://arxiv.org/abs/1901.05894
        return lambda x: x * mtf.tanh(x)
    elif activation_fn == "seagull": # https://arxiv.org/abs/2011.11713
        return lambda x: mtf.log(1 + x ** 2)
    elif activation_fn == "snake": # https://arxiv.org/abs/2006.08195
        return lambda x: x + mtf.sin(x) ** 2
    else:
        raise ValueError(f'unknown activation function "{activation_fn}" in config')
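Several of the entries above are plain elementwise formulas; a small NumPy sketch (no mesh_tensorflow needed) makes it easy to check them numerically:

import numpy as np

def softplus(x): return np.log1p(np.exp(x))
def mish(x):     return x * np.tanh(softplus(x))   # https://arxiv.org/abs/1908.08681
def tanhexp(x):  return x * np.tanh(np.exp(x))     # https://arxiv.org/abs/2003.09855
def snake(x):    return x + np.sin(x) ** 2         # https://arxiv.org/abs/2006.08195

x = np.linspace(-3.0, 3.0, 7)
for fn in (mish, tanhexp, snake):
    print(fn.__name__, fn(x).round(3))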
Example no. 3
  def decode(self,
             inputs,
             variable_dtype=mtf.VariableDType(tf.float32),
             beam_size=1,
             alpha=0.6,
             temperature=0.0,
             decode_length_multiplier=1.5,
             decode_length_constant=10):
    """Sampling or beam search.

    TODO(noam): should we make the output length dimension different from the
    input length dimension?

    Args:
      inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim]
      variable_dtype: a mtf.VariableDType
      beam_size: an integer >= 1
      alpha: a floating point value (length bonus for beam search)
      temperature: a value between 0 and 1 (must be 0 if beam_size > 1)
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      decode_length_multiplier: a float
      decode_length_constant: a float

    Returns:
      a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    """
    encoder_layer_outputs = []
    shared_params = self._shared_params(inputs.mesh, variable_dtype)
    encoder_sequence_id = mtf.minimum(inputs, 1)
    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs=inputs,
        targets=None,
        compute_loss=False,
        mode=tf.estimator.ModeKeys.PREDICT,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        shared_params=shared_params,
        layer_outputs=encoder_layer_outputs)
    del encoder_loss
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
    encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
        encoder_sequence_id)
    if beam_size == 1:
      ids_shape = inputs.shape
      partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
      return self.decoder.sample_autoregressive(
          partial_sequences,
          temperature=temperature,
          variable_dtype=variable_dtype,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          shared_params=shared_params,
          has_partial_sequences=False,
          encoder_layer_outputs=encoder_layer_outputs)
    else:
      if temperature != 0:
        raise ValueError(
            "don't know how to beam search with nonzero temperature")
      # beam search
      beam_dim = mtf.Dimension("beam", beam_size)
      batch_dims = inputs.shape[:-1]
      length_dim = inputs.shape[-1]
      ids_shape = mtf.Shape(batch_dims + [beam_dim, length_dim])
      partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * decode_length_multiplier
          + decode_length_constant, tf.int32)
      return self.decoder.beam_search(
          partial_sequences,
          decode_length,
          variable_dtype=variable_dtype,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          alpha=alpha,
          shared_params=shared_params,
          encoder_layer_outputs=encoder_layer_outputs)
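The decode-length heuristic above is simple arithmetic on the longest non-padding input; a hedged plain-Python sketch (names are illustrative, not mesh_tensorflow API):

def estimated_decode_length(input_ids, multiplier=1.5, constant=10):
  # Mirrors mtf.cast(inputs, tf.bool): non-zero ids count as real tokens.
  max_input_length = max(sum(1 for t in seq if t != 0) for seq in input_ids)
  return int(max_input_length * multiplier + constant)

# Longest sequence has 4 real tokens, so beam search decodes up to 4 * 1.5 + 10 = 16 steps.
print(estimated_decode_length([[5, 7, 0, 0], [3, 4, 8, 9]]))  # 16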
Example no. 4
def hybrid_attention(q,
                     k,
                     v,
                     context,
                     memory_length_dim,
                     key_dim,
                     value_dim,
                     bias=None,
                     dropout_rate=0.0,
                     dropout_broadcast_dims=None,
                     extra_logit=None):
  """Dot-product attention - doesn't use positional dimensions.

  key_dim is a Dimension representing the channels in the queries and keys
  value_dim is a Dimension representing the channels in values
  memory_length_dim is a Dimension representing the different key/value pairs.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
  other_memory_dims is a subset of other_query_dims

  Typically, other_query_dims={batch, heads, length}
  Typically, other_memory_dims={batch, heads}

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    context: context of the attention layer.
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: a Tensor to be added into the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """
  logits = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
  if bias is not None:
    logits += bias

  query_length_dim = mtf.Dimension("length", memory_length_dim.size)
  doubly_coeff = mtf.get_variable(
      context.mesh, "doubly_coeff", [],
      initializer=tf.constant_initializer(0.5),
      dtype=context.variable_dtype)
  doubly_coeff = mtf.maximum(mtf.minimum(doubly_coeff, 1.), 0.)

  upper_weights = mtf.softmax(
      logits, memory_length_dim, extra_logit=extra_logit)

  lower_log_weights = mtf.log_softmax(
      logits, query_length_dim, extra_logit=extra_logit)
  doubly_weights = mtf.softmax(
      lower_log_weights, memory_length_dim, extra_logit=extra_logit)

  weights = doubly_coeff * doubly_weights + (1. - doubly_coeff) * upper_weights
  weights = mtf.dropout(
      weights, context.train, 1.0 - dropout_rate,
      noise_shape=weights.shape - dropout_broadcast_dims)
  outputs_shape = q.shape - key_dim + value_dim
  outputs = mtf.einsum([weights, v], outputs_shape)
  return outputs
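The mixing of the two normalizations above can be checked with a small NumPy sketch; coeff plays the role of doubly_coeff, the extra_logit option is omitted, and this is illustrative rather than the mesh_tensorflow implementation:

import numpy as np

def hybrid_weights(logits, coeff=0.5):
  # logits: [query_length, memory_length]
  def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

  def log_softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

  upper = softmax(logits, axis=-1)           # normalize over memory positions
  lower_log = log_softmax(logits, axis=0)    # log-normalize over query positions
  doubly = softmax(lower_log, axis=-1)       # then renormalize over memory positions
  return coeff * doubly + (1.0 - coeff) * upper

w = hybrid_weights(np.random.default_rng(0).normal(size=(4, 6)))
print(w.sum(axis=-1))  # every row still sums to 1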
Example no. 5
    def decode(self,
               inputs,
               variable_dtype=mtf.VariableDType(tf.float32),
               beam_size=1,
               alpha=0.6,
               temperature=0.0,
               sampling_keep_top_k=-1,
               decode_length_multiplier=1.5,
               decode_length_constant=10,
               max_decode_length=None):
        """Sampling or beam search for Funnel Transformer.

    Args:
      inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim]
      variable_dtype: a mtf.VariableDType
      beam_size: an integer >= 1
      alpha: a floating point value (length bonus for beam search)
      temperature: a value between 0 and 1 (must be 0 if beam_size > 1)
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      sampling_keep_top_k: a value between 1 and vocab_size used to sample from
        only the k most likely logits. Set to -1 to sample from all logits.
      decode_length_multiplier: a float
      decode_length_constant: a float
      max_decode_length: an optional integer

    Returns:
      a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    """
        encoder_layer_outputs = []
        shared_params = self._shared_params(inputs.mesh, variable_dtype)
        encoder_sequence_id = mtf.minimum(inputs, 1)
        encoder_output, encoder_loss = self.encoder.call_simple(
            inputs=inputs,
            targets=None,
            compute_loss=False,
            mode=tf.estimator.ModeKeys.PREDICT,
            variable_dtype=variable_dtype,
            sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            layer_outputs=encoder_layer_outputs)
        del encoder_loss
        encoder_output = mtf.layers.rename_length_to_memory_length(
            encoder_output)

        # The sequence_id is updated inside the layer_stack due to pooling. So we
        # need to use the updated sequence_id stored in the context.
        encoder_sequence_id = self.encoder.layer_stack.context.sequence_id
        encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
            encoder_sequence_id)
        batch_dims = inputs.shape[:-1]
        length_dim = inputs.shape[-1]
        if max_decode_length is None:
            decode_length_dim = length_dim
        else:
            decode_length_dim = mtf.Dimension("length", max_decode_length)
        if beam_size == 1:
            ids_shape = mtf.Shape(batch_dims + [decode_length_dim])
            partial_sequences = mtf.zeros(inputs.mesh,
                                          ids_shape,
                                          dtype=tf.int32)
            return self.decoder.sample_autoregressive(
                partial_sequences,
                temperature=temperature,
                sampling_keep_top_k=sampling_keep_top_k,
                variable_dtype=variable_dtype,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                encoder_inputs=mtf.layers.rename_length_to_memory_length(
                    inputs),
                shared_params=shared_params,
                has_partial_sequences=False,
                encoder_layer_outputs=encoder_layer_outputs)
        else:
            if temperature != 0:
                raise ValueError(
                    "don't know how to beam search with nonzero temperature")
            if sampling_keep_top_k != -1:
                raise ValueError(
                    "don't know how to beam search with top-k value other than -1."
                )
            # beam search
            beam_dim = mtf.Dimension("beam", beam_size)
            ids_shape = mtf.Shape(batch_dims + [beam_dim, decode_length_dim])
            partial_sequences = mtf.zeros(inputs.mesh,
                                          ids_shape,
                                          dtype=tf.int32)
            input_length = mtf.reduce_sum(mtf.to_float(
                mtf.cast(inputs, tf.bool)),
                                          reduced_dim=length_dim)
            max_input_length = mtf.reduce_max(input_length)
            decode_length = mtf.cast(
                max_input_length * decode_length_multiplier +
                decode_length_constant, tf.int32)
            return self.decoder.beam_search(
                partial_sequences,
                decode_length,
                variable_dtype=variable_dtype,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                encoder_inputs=inputs,
                alpha=alpha,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs)
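The sampling_keep_top_k option described in the docstring keeps only the k largest logits before sampling; the NumPy sketch below shows the usual form of that filtering (it is an assumption that sample_autoregressive applies it exactly this way):

import numpy as np

def sample_top_k(logits, k=-1, temperature=1.0, rng=np.random.default_rng(0)):
    # k == -1 samples from all logits; temperature == 0.0 falls back to argmax.
    if k != -1:
        kth_largest = np.sort(logits)[-k]
        logits = np.where(logits >= kth_largest, logits, -1e9)
    if temperature == 0.0:
        return int(np.argmax(logits))
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    return int(rng.choice(len(logits), p=probs))

print(sample_top_k(np.array([2.0, 1.0, 0.5, -1.0]), k=2, temperature=0.8))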
Example no. 6
    def call_simple(self,
                    inputs,
                    targets,
                    compute_loss,
                    mode=tf.estimator.ModeKeys.TRAIN,
                    variable_dtype=mtf.VariableDType(tf.float32),
                    encoder_sequence_id=None,
                    decoder_sequence_id=None,
                    decoder_subsequence_id=None,
                    encoder_position=None,
                    decoder_position=None,
                    num_microbatches=1):
        """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    This class inherits from transformer.Bitransformer with one difference: the
    encoder is a Funnel Transformer, so the length dimension is reduced and the
    decoder needs to use the updated `encoder_sequence_id`.

    Args:
      inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
      targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
      compute_loss: a boolean
      mode: a tf.estimator.ModeKeys
      variable_dtype: a mtf.VariableDType
      encoder_sequence_id: an optional Tensor
      decoder_sequence_id: an optional Tensor
      decoder_subsequence_id: an optional Tensor
      encoder_position: an optional Tensor
      decoder_position: an optional Tensor
      num_microbatches: integer - greater than one if the step has been
        serialized into multiple microbatches to save memory.

    Returns:
      logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
      loss: an optional Scalar (if compute_loss=True)
    """
        # encoder_sequence_id and decoder_sequence_id are used to delineate packed
        # examples but are also necessary to indicate padding where sequence_id==0.
        # If they are absent, then we assume that padding is indicated by zeros in
        # the inputs/targets, and we make up sequence_id tensors to indicate this.
        if encoder_sequence_id is None:
            encoder_sequence_id = mtf.minimum(inputs, 1)
        if decoder_sequence_id is None:
            decoder_sequence_id = mtf.minimum(targets, 1)
        encoder_layer_outputs = []
        shared_params = self._shared_params(inputs.mesh, variable_dtype)
        encoder_output, encoder_loss = self.encoder.call_simple(
            inputs,
            None,
            compute_loss,
            mode=mode,
            variable_dtype=variable_dtype,
            sequence_id=encoder_sequence_id,
            position=encoder_position,
            shared_params=shared_params,
            layer_outputs=encoder_layer_outputs,
            num_microbatches=num_microbatches)
        encoder_output = mtf.layers.rename_length_to_memory_length(
            encoder_output)

        # The sequence_id is updated inside the layer_stack due to pooling. So we
        # need to use the updated sequence_id stored in the context.
        encoder_sequence_id = self.encoder.layer_stack.context.sequence_id
        encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
            encoder_sequence_id)

        logits, loss = self.decoder.call_simple(
            transformer.autoregressive_inputs(targets,
                                              sequence_id=decoder_sequence_id),
            targets,
            compute_loss,
            mode=mode,
            variable_dtype=variable_dtype,
            sequence_id=decoder_sequence_id,
            subsequence_id=decoder_subsequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
            position=decoder_position,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            num_microbatches=num_microbatches)
        if loss is not None and encoder_loss is not None:
            loss += encoder_loss
        return logits, loss
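transformer.autoregressive_inputs is assumed here to perform the usual teacher-forcing shift: each decoder position sees the previous target, and the first position of every packed segment is reset to 0. A hedged NumPy sketch of that assumed behavior:

import numpy as np

def shift_targets_right(targets, sequence_id=None):
    # Assumed behavior, shown for illustration; not the mesh_tensorflow code itself.
    inputs = np.roll(targets, 1, axis=-1)
    inputs[..., 0] = 0
    if sequence_id is not None:
        prev_id = np.roll(sequence_id, 1, axis=-1)
        prev_id[..., 0] = 0
        inputs = np.where(sequence_id == prev_id, inputs, 0)  # reset at segment starts
    return inputs

print(shift_targets_right(np.array([[11, 12, 13, 21, 22, 0]]),
                          np.array([[1, 1, 1, 2, 2, 0]])))
# [[ 0 11 12  0 21  0]]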
Example no. 7
def get_activation_fn(params):
    activation_fn = params.get("activation_fn", "gelu")

    def _arcsinh(x):
        return mtf.log(x + mtf.sqrt(1 + x**2))

    def _var(x, init):
        return mtf.get_variable(x.mesh,
                                f"activation-{random.randint(0, 2 ** 32):x}",
                                [],
                                initializer=tf.constant_initializer(init),
                                dtype=x.dtype)

    def _pos_var(x, val):
        return mtf.softplus(_var(x, 0)) + val

    if activation_fn == "gelu":  # https://arxiv.org/abs/1606.08415
        return mtf.gelu
    elif activation_fn == "relu":
        return mtf.relu
    elif activation_fn == "sigmoid":
        return mtf.sigmoid
    elif activation_fn == "tanh":
        return mtf.tanh
    elif activation_fn == "selu":  # https://arxiv.org/abs/1706.02515
        return mtf.selu
    elif activation_fn == "elu":  # https://arxiv.org/abs/1511.07289
        return mtf.elu
    elif activation_fn == "lrelu001":
        return lambda x: mtf.leaky_relu(x, alpha=0.01)
    elif activation_fn == "lrelu020":
        return lambda x: mtf.leaky_relu(x, alpha=0.20)

    elif activation_fn == "abs":
        return mtf.abs
    elif activation_fn == "id":
        return lambda x: x
    elif activation_fn == "sin":
        return mtf.sin
    elif activation_fn == "cos":
        return mtf.cos
    elif activation_fn == "sign":
        return mtf.sign
    elif activation_fn == "triangle_relax":
        return lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49
    elif activation_fn == "square_relax":
        return lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7
    elif activation_fn == "spike":
        return lambda x: 1 / (1 + x**2)
    elif activation_fn == "spike2":
        return lambda x: mtf.exp(-x**2)

    elif activation_fn == "tanhshrink":
        return lambda x: x - mtf.tanh(x)
    elif activation_fn == "softsign":
        return lambda x: x / (mtf.abs(x) + 1)
    elif activation_fn == "softmax":
        return lambda x: mtf.softmax(x, x.shape[-1])
    elif activation_fn == "logsoftmax":
        return lambda x: mtf.log_softmax(x, x.shape[-1])
    elif activation_fn == "bipolarsigmoid":
        return lambda x: mtf.sigmoid(x) * 2 - 1
    elif activation_fn == "rrelu":  # https://arxiv.org/abs/1505.00853

        def _rrelu_fn(x):
            negative_scale = random.random()
            return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)

        return _rrelu_fn
    elif activation_fn == "elish":  # https://arxiv.org/abs/1808.00783v1

        def _elish_fn(x):
            cond = mtf.cast(mtf.greater(x, 0), x.dtype)
            exp = mtf.exp(x)
            return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp +
                                                                    1)

        return _elish_fn

    elif activation_fn == "silu":  # https://arxiv.org/abs/1710.05941
        return mtf.swish

    elif activation_fn == "arcsinh":
        return _arcsinh

    # parametric
    elif activation_fn == "aria":  # https://arxiv.org/abs/1805.08878
        return lambda x: x * (_var(x, 0) + _var(x, 1) / (_pos_var(x, 0) + _var(
            x, 1) * mtf.exp(_var(x, -1) * x)**(1 / _pos_var(x, 1))))
    elif activation_fn == "prelu":  # https://arxiv.org/abs/1502.01852
        return lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2))
    elif activation_fn == "parcsinh":
        return lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1))
    elif activation_fn == "psoftplus":
        return lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0)
    elif activation_fn == "proottanh":
        return lambda x: (x**_pos_var(x, 2) + _pos_var(x, 1))**(1 / _pos_var(
            x, 3)) * mtf.tanh(x)

    # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
    elif activation_fn == "maxsig":
        return lambda x: mtf.maximum(x, mtf.sigmoid(x))
    elif activation_fn == "cosid":
        return lambda x: mtf.cos(x) - x
    elif activation_fn == "minsin":
        return lambda x: mtf.minimum(x, mtf.sin(x))
    elif activation_fn == "maxtanh":
        return lambda x: mtf.maximum(x, mtf.tanh(x))

    elif activation_fn == "softplus":
        return mtf.softplus
    elif activation_fn == "mish":  # https://arxiv.org/abs/1908.08681
        return lambda x: x * mtf.tanh(mtf.softplus(x))
    elif activation_fn == "tanhexp":  # https://arxiv.org/abs/2003.09855
        return lambda x: x * mtf.tanh(mtf.exp(x))
    elif activation_fn == "lisht":  # https://arxiv.org/abs/1901.05894
        return lambda x: x * mtf.tanh(x)
    elif activation_fn == "seagull":  # https://arxiv.org/abs/2011.11713
        return lambda x: mtf.log(1 + x**2)
    elif activation_fn == "snake":  # https://arxiv.org/abs/2006.08195
        return lambda x: x + mtf.sin(x)**2

    elif activation_fn == "roottanh":  # made up
        return lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x)
    elif activation_fn == "softplusmone":  # made up
        return lambda x: mtf.softplus(x) - 1

    else:
        raise ValueError(
            f'unknown activation function "{activation_fn}" in config')
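The parametric entries above ("prelu", "parcsinh", "psoftplus", "proottanh", "aria") all rely on the same trick: _var creates a learned scalar from an init value and _pos_var keeps it strictly positive via softplus plus an offset. A NumPy sketch with plain floats standing in for the mtf variables (illustrative only):

import numpy as np

def softplus(x): return np.log1p(np.exp(x))

# Stand-ins for mtf.get_variable: a raw learned scalar and a positivity-constrained one.
def var(init):          return float(init)
def pos_var(raw, val):  return softplus(raw) + val  # softplus(raw) > 0, so result > val

def parcsinh(x, scale_init=1.0, gain_raw=0.0):
    scale = var(scale_init)        # unconstrained, starts at 1
    gain = pos_var(gain_raw, 1.0)  # starts near 1 + log(2), always > 1
    return scale * np.arcsinh(x * gain)

print(parcsinh(np.array([-2.0, 0.0, 2.0])).round(3))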
Example no. 8
def get_activation_fn(params):
    activation_fn = params.get("activation_fn", "gelu")
    if activation_fn == "gelu":  # https://arxiv.org/abs/1606.08415
        return mtf.gelu
    elif activation_fn == "relu":
        return mtf.relu
    elif activation_fn == "sigmoid":
        return mtf.sigmoid
    elif activation_fn == "tanh":
        return mtf.tanh
    elif activation_fn == "selu":  # https://arxiv.org/abs/1706.02515
        return mtf.selu
    elif activation_fn == "elu":  # https://arxiv.org/abs/1511.07289
        return mtf.elu

    elif activation_fn == "abs":
        return mtf.abs
    elif activation_fn == "id":
        return lambda x: x
    elif activation_fn == "sin":
        return mtf.sin
    elif activation_fn == "cos":
        return mtf.cos
    elif activation_fn == "sign":
        return mtf.sign
    elif activation_fn == "triangle_relax":
        return lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49
    elif activation_fn == "square_relax":
        return lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7
    elif activation_fn == "spike":
        return lambda x: 1 / (1 + x**2)
    elif activation_fn == "spike2":
        return lambda x: mtf.exp(-x**2)

    elif activation_fn == "tanhshrink":
        return lambda x: x - mtf.tanh(x)
    elif activation_fn == "softsign":
        return lambda x: x / (mtf.abs(x) + 1)
    elif activation_fn == "softmax":
        return lambda x: mtf.softmax(x, x.shape[-1])
    elif activation_fn == "logsoftmax":
        return lambda x: mtf.log_softmax(x, x.shape[-1])
    elif activation_fn == "bipolarsigmoid":
        return lambda x: mtf.sigmoid(x) * 2 - 1
    elif activation_fn == "rrelu":  # https://arxiv.org/abs/1505.00853

        def _rrelu_fn(x):
            negative_scale = random.random()
            return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)

        return _rrelu_fn
    elif activation_fn == "elish":  # https://arxiv.org/abs/1808.00783v1

        def _elish_fn(x):
            cond = mtf.cast(mtf.greater(x, 0), x.dtype)
            exp = mtf.exp(x)
            return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp +
                                                                    1)

        return _elish_fn

    # swish activations
    elif activation_fn == "swish":  # https://arxiv.org/abs/1710.05941
        return mtf.swish

    # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
    elif activation_fn == "maxsig":
        return lambda x: mtf.maximum(x, mtf.sigmoid(x))
    elif activation_fn == "cosid":
        return lambda x: mtf.cos(x) - x
    elif activation_fn == "minsin":
        return lambda x: mtf.minimum(x, mtf.sin(x))
    elif activation_fn == "maxtanh":
        return lambda x: mtf.maximum(x, mtf.tanh(x))

    elif activation_fn == "softplus":
        return mtf.softplus
    elif activation_fn == "mish":  # https://arxiv.org/abs/1908.08681
        return lambda x: x * mtf.tanh(mtf.softplus(x))
    elif activation_fn == "tanhexp":  # https://arxiv.org/abs/2003.09855
        return lambda x: x * mtf.tanh(mtf.exp(x))
    elif activation_fn == "lisht":  # https://arxiv.org/abs/1901.05894
        return lambda x: x * mtf.tanh(x)
    elif activation_fn == "seagull":  # https://arxiv.org/abs/2011.11713
        return lambda x: mtf.log(1 + x**2)
    elif activation_fn == "snake":  # https://arxiv.org/abs/2006.08195
        return lambda x: x + mtf.sin(x)**2

    elif activation_fn == "roottanh":  # made up
        return lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x)
    elif activation_fn == "softplusmone":  # made up
        return lambda x: mtf.softplus(x) - 1

    else:
        raise ValueError(
            f'unknown activation function "{activation_fn}" in config')
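The same dispatch can also be written as a dict of lambdas looked up by name, which appears to be the form of the next, truncated example. A minimal NumPy sketch of that pattern (the formulas match the entries above; gelu here is the usual tanh approximation, and the function and dict names are illustrative):

import numpy as np

ACTIVATIONS = {
    "gelu":    lambda x: 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi)
                                                * (x + 0.044715 * x ** 3))),
    "cosid":   lambda x: np.cos(x) - x,
    "lisht":   lambda x: x * np.tanh(x),
    "seagull": lambda x: np.log1p(x ** 2),
}

def get_activation_fn(params):
    name = params.get("activation_fn", "gelu")
    try:
        return ACTIVATIONS[name]
    except KeyError:
        raise ValueError(f'unknown activation function "{name}" in config')

print(get_activation_fn({"activation_fn": "lisht"})(np.array([-1.0, 0.0, 1.0])))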
Example no. 9
                 (1 / _pos_var(x, 1)))),
 'prelu':
 lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)),
 'parcsinh':
 lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)),
 'psoftplus':
 lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0),
 'proottanh':
 lambda x:
 (x**_pos_var(x, 2) + _pos_var(x, 1))**(1 / _pos_var(x, 3)) * mtf.tanh(x),
 'maxsig':
 lambda x: mtf.maximum(x, mtf.sigmoid(x)),
 'cosid':
 lambda x: mtf.cos(x) - x,
 'minsin':
 lambda x: mtf.minimum(x, mtf.sin(x)),
 'maxtanh':
 lambda x: mtf.maximum(x, mtf.tanh(x)),
 'mish':
 lambda x: x * mtf.tanh(mtf.softplus(x)),
 'tanhexp':
 lambda x: x * mtf.tanh(mtf.exp(x)),
 'lisht':
 lambda x: x * mtf.tanh(x),
 'seagull':
 lambda x: mtf.log(1 + x**2),
 'snake':
 lambda x: x + mtf.sin(x)**2,
 'roottanh':
 lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x),
 'softplusmone':