def _relative_position_bucket(relative_position, bidirectional=True,
                              num_buckets=32, max_distance=128):
    """Translate relative position to a bucket number for relative attention.

    The relative position is defined as memory_position - query_position,
    i.e. the distance in tokens from the attending position to the
    attended-to position. If bidirectional=False, then positive relative
    positions are invalid.

    We use smaller buckets for small absolute relative_position and larger
    buckets for larger absolute relative_positions. All relative positions
    >=max_distance map to the same bucket. All relative positions
    <=-max_distance map to the same bucket. This should allow for more
    graceful generalization to longer sequences than the model has been
    trained on.

    Args:
      relative_position: an int32 Tensor
      bidirectional: a boolean - whether the attention is bidirectional
      num_buckets: an integer
      max_distance: an integer

    Returns:
      a Tensor with the same shape as relative_position, containing int32
        values in the range [0, num_buckets)
    """
    ret = 0
    n = -relative_position
    if bidirectional:
        num_buckets //= 2
        ret += mtf.to_int32(mtf.less(n, 0)) * num_buckets
        n = mtf.abs(n)
    else:
        n = mtf.maximum(n, 0)
    # now n is in the range [0, inf)
    max_exact = num_buckets // 2
    is_small = mtf.less(n, max_exact)
    val_if_large = max_exact + mtf.to_int32(
        mtf.log(mtf.to_float(n) / max_exact)
        / math.log(max_distance / max_exact) * (num_buckets - max_exact))
    val_if_large = mtf.minimum(val_if_large, num_buckets - 1)
    ret += mtf.where(is_small, n, val_if_large)
    return ret
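# A minimal NumPy mirror of the bucketing above, included only as an
# illustrative sketch (it is not called by the model code). The name
# `_relative_position_bucket_example` is hypothetical; it assumes numpy is
# available and uses the same defaults as the mtf version.
def _relative_position_bucket_example(relative_position, bidirectional=True,
                                      num_buckets=32, max_distance=128):
    import numpy as np
    ret = 0
    n = -np.asarray(relative_position)
    if bidirectional:
        num_buckets //= 2
        ret = ret + (n < 0).astype(np.int32) * num_buckets
        n = np.abs(n)
    else:
        n = np.maximum(n, 0)
    max_exact = num_buckets // 2
    is_small = n < max_exact
    # np.maximum(n, 1) only guards against log(0) warnings; those entries are
    # masked out by is_small below, so the selected values are unchanged.
    val_if_large = max_exact + (
        np.log(np.maximum(n, 1) / max_exact)
        / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    ).astype(np.int32)
    val_if_large = np.minimum(val_if_large, num_buckets - 1)
    return ret + np.where(is_small, n, val_if_large)
# With the defaults and bidirectional=True, distances 0..7 behind the query
# get their own buckets, 8..127 map into log-spaced buckets 8..15, and >=128
# all land in bucket 15 (plus an offset of 16 when the attended-to position
# is ahead of the query).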
def get_activation_fn(params):
    activation_fn = params.get("activation_fn", "gelu")

    if activation_fn == "gelu":  # https://arxiv.org/abs/1606.08415
        return mtf.gelu
    elif activation_fn == "relu":
        return mtf.relu
    elif activation_fn == "sigmoid":
        return mtf.sigmoid
    elif activation_fn == "tanh":
        return mtf.tanh
    elif activation_fn == "selu":  # https://arxiv.org/abs/1706.02515
        return mtf.selu
    elif activation_fn == "elu":  # https://arxiv.org/abs/1511.07289
        return mtf.elu
    # swish activations
    elif activation_fn == "swish":  # https://arxiv.org/abs/1710.05941
        return mtf.swish
    # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
    elif activation_fn == "maxsig":
        return lambda x: mtf.maximum(x, mtf.sigmoid(x))
    elif activation_fn == "cosid":
        return lambda x: mtf.cos(x) - x
    elif activation_fn == "minsin":
        return lambda x: mtf.minimum(x, mtf.sin(x))
    elif activation_fn == "maxtanh":
        return lambda x: mtf.maximum(x, mtf.tanh(x))
    elif activation_fn == "softplus":
        return mtf.softplus
    elif activation_fn == "mish":  # https://arxiv.org/abs/1908.08681
        return lambda x: x * mtf.tanh(mtf.softplus(x))
    elif activation_fn == "tanhexp":  # https://arxiv.org/abs/2003.09855
        return lambda x: x * mtf.tanh(mtf.exp(x))
    elif activation_fn == "lisht":  # https://arxiv.org/abs/1901.05894
        return lambda x: x * mtf.tanh(x)
    elif activation_fn == "seagull":  # https://arxiv.org/abs/2011.11713
        return lambda x: mtf.log(1 + x ** 2)
    elif activation_fn == "snake":  # https://arxiv.org/abs/2006.08195
        return lambda x: x + mtf.sin(x) ** 2
    else:
        raise ValueError('unknown activation function "activation_fn" in config')
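# Hedged usage sketch: `example_params` and `hidden` are illustrative names,
# not defined in this file. get_activation_fn only reads the "activation_fn"
# key and falls back to mtf.gelu when it is absent.
#
#   example_params = {"activation_fn": "mish"}
#   act = get_activation_fn(example_params)   # -> lambda x: x * mtf.tanh(mtf.softplus(x))
#   hidden = act(hidden)                      # hidden is an mtf.Tensor
#   act = get_activation_fn({})               # -> mtf.gelu (default)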
def decode(self,
           inputs,
           variable_dtype=mtf.VariableDType(tf.float32),
           beam_size=1,
           alpha=0.6,
           temperature=0.0,
           decode_length_multiplier=1.5,
           decode_length_constant=10):
    """Sampling or beam search.

    TODO(noam): should we make the output length dimension different from the
    input length dimension?

    Args:
      inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim]
      variable_dtype: a mtf.VariableDType
      beam_size: an integer >= 1
      alpha: a floating point value (length bonus for beam search)
      temperature: a value between 0 and 1 (must be 0 if beam_size > 1)
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      decode_length_multiplier: a float
      decode_length_constant: a float

    Returns:
      a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    """
    encoder_layer_outputs = []
    shared_params = self._shared_params(inputs.mesh, variable_dtype)
    encoder_sequence_id = mtf.minimum(inputs, 1)
    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs=inputs,
        targets=None,
        compute_loss=False,
        mode=tf.estimator.ModeKeys.PREDICT,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        shared_params=shared_params,
        layer_outputs=encoder_layer_outputs)
    del encoder_loss
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
    encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
        encoder_sequence_id)
    if beam_size == 1:
        ids_shape = inputs.shape
        partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
        return self.decoder.sample_autoregressive(
            partial_sequences,
            temperature=temperature,
            variable_dtype=variable_dtype,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            has_partial_sequences=False,
            encoder_layer_outputs=encoder_layer_outputs)
    else:
        if temperature != 0:
            raise ValueError(
                "don't know how to beam search with nonzero temperature")
        # beam search
        beam_dim = mtf.Dimension("beam", beam_size)
        batch_dims = inputs.shape[:-1]
        length_dim = inputs.shape[-1]
        ids_shape = mtf.Shape(batch_dims + [beam_dim, length_dim])
        partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * decode_length_multiplier
            + decode_length_constant, tf.int32)
        return self.decoder.beam_search(
            partial_sequences,
            decode_length,
            variable_dtype=variable_dtype,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            alpha=alpha,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs)
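# Hedged usage sketch: `model` and `token_ids` are illustrative names for a
# Bitransformer instance and an int32 [<batch_dims>, length] Tensor; neither
# is defined in this file.
#
#   # temperature sampling (beam_size=1); temperature=0.0 is argmax decoding
#   sampled = model.decode(token_ids, temperature=0.0)
#
#   # beam search with a length bonus; temperature must stay 0.0
#   beams = model.decode(token_ids, beam_size=4, alpha=0.6)
#
# For beam search the decode length is derived from the longest non-padding
# input as max_input_length * decode_length_multiplier +
# decode_length_constant, cast to int32.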
def hybrid_attention(q,
                     k,
                     v,
                     context,
                     memory_length_dim,
                     key_dim,
                     value_dim,
                     bias=None,
                     dropout_rate=0.0,
                     dropout_broadcast_dims=None,
                     extra_logit=None):
    """Hybrid dot-product attention - doesn't use positional dimensions.

    key_dim is a Dimension representing the channels in the queries and keys
    value_dim is a Dimension representing the channels in values
    memory_length_dim is a Dimension representing the different key/value pairs.

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims

    Typically, other_query_dims={batch, heads, length}
    Typically, other_memory_dims={batch, heads}

    Args:
      q: a Tensor
      k: a Tensor
      v: a Tensor
      context: context of the attention layer.
      memory_length_dim: a Dimension
      key_dim: a Dimension
      value_dim: a Dimension
      bias: a Tensor to be added into the attention logits.
      dropout_rate: a float.
      dropout_broadcast_dims: an optional list of mtf.Dimension
      extra_logit: an optional scalar or tensor

    Returns:
      Tensor with shape q.shape - key_dim + value_dim
    """
    logits = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
    if bias is not None:
        logits += bias
    query_length_dim = mtf.Dimension("length", memory_length_dim.size)
    # Learned mixing coefficient between the two normalizations, clipped to [0, 1].
    doubly_coeff = mtf.get_variable(
        context.mesh, "doubly_coeff", [],
        initializer=tf.constant_initializer(0.5),
        dtype=context.variable_dtype)
    doubly_coeff = mtf.maximum(mtf.minimum(doubly_coeff, 1.), 0.)
    # Standard attention weights: softmax over the memory axis.
    upper_weights = mtf.softmax(
        logits, memory_length_dim, extra_logit=extra_logit)
    # Doubly-normalized weights: log-softmax over the query axis, then
    # softmax over the memory axis.
    lower_log_weights = mtf.log_softmax(
        logits, query_length_dim, extra_logit=extra_logit)
    doubly_weights = mtf.softmax(
        lower_log_weights, memory_length_dim, extra_logit=extra_logit)
    weights = doubly_coeff * doubly_weights + (1. - doubly_coeff) * upper_weights
    weights = mtf.dropout(
        weights, context.train, 1.0 - dropout_rate,
        noise_shape=weights.shape - dropout_broadcast_dims)
    outputs_shape = q.shape - key_dim + value_dim
    outputs = mtf.einsum([weights, v], outputs_shape)
    return outputs
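# Hedged call sketch: the tensor and dimension names below (memory_length,
# d_k, d_v, attention_bias) are illustrative, not defined in this file.
#
#   out = hybrid_attention(q, k, v, context,
#                          memory_length_dim=memory_length,
#                          key_dim=d_k,
#                          value_dim=d_v,
#                          bias=attention_bias,
#                          dropout_rate=0.1)
#
# `out` has the shape of q with key_dim replaced by value_dim, as stated in
# the docstring above.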
def decode(self,
           inputs,
           variable_dtype=mtf.VariableDType(tf.float32),
           beam_size=1,
           alpha=0.6,
           temperature=0.0,
           sampling_keep_top_k=-1,
           decode_length_multiplier=1.5,
           decode_length_constant=10,
           max_decode_length=None):
    """Sampling or beam search for Funnel Transformer.

    Args:
      inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim]
      variable_dtype: a mtf.VariableDType
      beam_size: an integer >= 1
      alpha: a floating point value (length bonus for beam search)
      temperature: a value between 0 and 1 (must be 0 if beam_size > 1)
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      sampling_keep_top_k: a value between 1 and vocab_size used to sample from
        only the k most likely logits. Set to -1 to sample from all logits.
      decode_length_multiplier: a float
      decode_length_constant: a float
      max_decode_length: an optional integer

    Returns:
      a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    """
    encoder_layer_outputs = []
    shared_params = self._shared_params(inputs.mesh, variable_dtype)
    encoder_sequence_id = mtf.minimum(inputs, 1)
    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs=inputs,
        targets=None,
        compute_loss=False,
        mode=tf.estimator.ModeKeys.PREDICT,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        shared_params=shared_params,
        layer_outputs=encoder_layer_outputs)
    del encoder_loss
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
    # The sequence_id is updated inside the layer_stack due to pooling. So we
    # need to use the updated sequence_id stored in the context.
    encoder_sequence_id = self.encoder.layer_stack.context.sequence_id
    encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
        encoder_sequence_id)
    batch_dims = inputs.shape[:-1]
    length_dim = inputs.shape[-1]
    if max_decode_length is None:
        decode_length_dim = length_dim
    else:
        decode_length_dim = mtf.Dimension("length", max_decode_length)
    if beam_size == 1:
        ids_shape = mtf.Shape(batch_dims + [decode_length_dim])
        partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
        return self.decoder.sample_autoregressive(
            partial_sequences,
            temperature=temperature,
            sampling_keep_top_k=sampling_keep_top_k,
            variable_dtype=variable_dtype,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
            shared_params=shared_params,
            has_partial_sequences=False,
            encoder_layer_outputs=encoder_layer_outputs)
    else:
        if temperature != 0:
            raise ValueError(
                "don't know how to beam search with nonzero temperature")
        if sampling_keep_top_k != -1:
            raise ValueError(
                "don't know how to beam search with top-k value other than -1.")
        # beam search
        beam_dim = mtf.Dimension("beam", beam_size)
        ids_shape = mtf.Shape(batch_dims + [beam_dim, decode_length_dim])
        partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32)
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * decode_length_multiplier
            + decode_length_constant, tf.int32)
        return self.decoder.beam_search(
            partial_sequences,
            decode_length,
            variable_dtype=variable_dtype,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            encoder_inputs=inputs,
            alpha=alpha,
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs)
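# Hedged usage sketch for the Funnel variant (`funnel_model` and `token_ids`
# are illustrative names): compared to the plain Bitransformer decode above,
# this one also accepts top-k sampling and an explicit maximum decode length.
#
#   sampled = funnel_model.decode(token_ids,
#                                 temperature=0.9,
#                                 sampling_keep_top_k=40,
#                                 max_decode_length=256)
#
# Nonzero temperature and sampling_keep_top_k != -1 are only valid with
# beam_size=1; beam search rejects both.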
def call_simple(self,
                inputs,
                targets,
                compute_loss,
                mode=tf.estimator.ModeKeys.TRAIN,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_sequence_id=None,
                decoder_sequence_id=None,
                decoder_subsequence_id=None,
                encoder_position=None,
                decoder_position=None,
                num_microbatches=1):
    """Compute logits based on inputs (all positions in parallel).

    This is called during training and evaluation.

    This class inherits from transformer.Bitransformer with one difference:
    the encoder is a Funnel Transformer, so the length dimension is reduced.
    The decoder needs to use the updated `encoder_sequence_id`.

    Args:
      inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
      targets: an optional int32 Tensor with shape [<batch_dims>, length_dim]
      compute_loss: a boolean
      mode: a tf.estimator.ModeKeys
      variable_dtype: a mtf.VariableDType
      encoder_sequence_id: an optional Tensor
      decoder_sequence_id: an optional Tensor
      decoder_subsequence_id: an optional Tensor
      encoder_position: an optional Tensor
      decoder_position: an optional Tensor
      num_microbatches: integer - greater than one if the step has been
        serialized into multiple microbatches to save memory.

    Returns:
      logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
      loss: an optional Scalar (if compute_loss=True)
    """
    # encoder_sequence_id and decoder_sequence_id are used to delineate packed
    # examples but are also necessary to indicate padding where sequence_id==0.
    # If they are absent, then we assume that padding is indicated by zeros in
    # the inputs/targets, and we make up sequence_id tensors to indicate this.
    if encoder_sequence_id is None:
        encoder_sequence_id = mtf.minimum(inputs, 1)
    if decoder_sequence_id is None:
        decoder_sequence_id = mtf.minimum(targets, 1)
    encoder_layer_outputs = []
    shared_params = self._shared_params(inputs.mesh, variable_dtype)
    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs,
        None,
        compute_loss,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        position=encoder_position,
        shared_params=shared_params,
        layer_outputs=encoder_layer_outputs,
        num_microbatches=num_microbatches)
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
    # The sequence_id is updated inside the layer_stack due to pooling. So we
    # need to use the updated sequence_id stored in the context.
    encoder_sequence_id = self.encoder.layer_stack.context.sequence_id
    encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
        encoder_sequence_id)
    logits, loss = self.decoder.call_simple(
        transformer.autoregressive_inputs(
            targets, sequence_id=decoder_sequence_id),
        targets,
        compute_loss,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=decoder_sequence_id,
        subsequence_id=decoder_subsequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
        position=decoder_position,
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs,
        num_microbatches=num_microbatches)
    if loss is not None and encoder_loss is not None:
        loss += encoder_loss
    return logits, loss
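# Hedged training-call sketch (`funnel`, `inputs`, `targets` are illustrative
# names for a Funnel bitransformer instance and int32 Tensors):
#
#   logits, loss = funnel.call_simple(
#       inputs, targets, compute_loss=True,
#       mode=tf.estimator.ModeKeys.TRAIN,
#       variable_dtype=mtf.VariableDType(tf.float32))
#
# Because the Funnel encoder pools along the length dimension, the decoder is
# fed the encoder_sequence_id recorded in the encoder's layer_stack context
# rather than one derived from the original inputs.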
def get_activation_fn(params):
    activation_fn = params.get("activation_fn", "gelu")

    def _arcsinh(x):
        return mtf.log(x + mtf.sqrt(1 + x**2))

    def _var(x, init):
        return mtf.get_variable(x.mesh,
                                f"activation-{random.randint(0, 2 ** 32):x}",
                                [],
                                initializer=tf.constant_initializer(init),
                                dtype=x.dtype)

    def _pos_var(x, val):
        return mtf.softplus(_var(x, 0)) + val

    if activation_fn == "gelu":  # https://arxiv.org/abs/1606.08415
        return mtf.gelu
    elif activation_fn == "relu":
        return mtf.relu
    elif activation_fn == "sigmoid":
        return mtf.sigmoid
    elif activation_fn == "tanh":
        return mtf.tanh
    elif activation_fn == "selu":  # https://arxiv.org/abs/1706.02515
        return mtf.selu
    elif activation_fn == "elu":  # https://arxiv.org/abs/1511.07289
        return mtf.elu
    elif activation_fn == "lrelu001":
        return lambda x: mtf.leaky_relu(x, alpha=0.01)
    elif activation_fn == "lrelu020":
        return lambda x: mtf.leaky_relu(x, alpha=0.20)
    elif activation_fn == "abs":
        return mtf.abs
    elif activation_fn == "id":
        return lambda x: x
    elif activation_fn == "sin":
        return mtf.sin
    elif activation_fn == "cos":
        return mtf.cos
    elif activation_fn == "sign":
        return mtf.sign
    elif activation_fn == "triangle_relax":
        return lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49
    elif activation_fn == "square_relax":
        return lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7
    elif activation_fn == "spike":
        return lambda x: 1 / (1 + x**2)
    elif activation_fn == "spike2":
        return lambda x: mtf.exp(-x**2)
    elif activation_fn == "tanhshrink":
        return lambda x: x - mtf.tanh(x)
    elif activation_fn == "softsign":
        return lambda x: x / (mtf.abs(x) + 1)
    elif activation_fn == "softmax":
        return lambda x: mtf.softmax(x, x.shape[-1])
    elif activation_fn == "logsoftmax":
        return lambda x: mtf.log_softmax(x, x.shape[-1])
    elif activation_fn == "bipolarsigmoid":
        return lambda x: mtf.sigmoid(x) * 2 - 1
    elif activation_fn == "rrelu":  # https://arxiv.org/abs/1505.00853
        def _rrelu_fn(x):
            negative_scale = random.random()
            return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)
        return _rrelu_fn
    elif activation_fn == "elish":  # https://arxiv.org/abs/1808.00783v1
        def _elish_fn(x):
            cond = mtf.cast(mtf.greater(x, 0), x.dtype)
            exp = mtf.exp(x)
            return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp + 1)
        return _elish_fn
    elif activation_fn == "silu":  # https://arxiv.org/abs/1710.05941
        return mtf.swish
    elif activation_fn == "arcsinh":
        return _arcsinh
    # parametric
    elif activation_fn == "aria":  # https://arxiv.org/abs/1805.08878
        return lambda x: x * (_var(x, 0) + _var(x, 1) / (_pos_var(x, 0) + _var(
            x, 1) * mtf.exp(_var(x, -1) * x)**(1 / _pos_var(x, 1))))
    elif activation_fn == "prelu":  # https://arxiv.org/abs/1502.01852
        return lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2))
    elif activation_fn == "parcsinh":
        return lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1))
    elif activation_fn == "psoftplus":
        return lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0)
    elif activation_fn == "proottanh":
        return lambda x: (x**_pos_var(x, 2) + _pos_var(x, 1))**(1 / _pos_var(
            x, 3)) * mtf.tanh(x)
    # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
    elif activation_fn == "maxsig":
        return lambda x: mtf.maximum(x, mtf.sigmoid(x))
    elif activation_fn == "cosid":
        return lambda x: mtf.cos(x) - x
    elif activation_fn == "minsin":
        return lambda x: mtf.minimum(x, mtf.sin(x))
    elif activation_fn == "maxtanh":
        return lambda x: mtf.maximum(x, mtf.tanh(x))
    elif activation_fn == "softplus":
        return mtf.softplus
    elif activation_fn == "mish":  # https://arxiv.org/abs/1908.08681
        return lambda x: x * mtf.tanh(mtf.softplus(x))
    elif activation_fn == "tanhexp":  # https://arxiv.org/abs/2003.09855
        return lambda x: x * mtf.tanh(mtf.exp(x))
    elif activation_fn == "lisht":  # https://arxiv.org/abs/1901.05894
        return lambda x: x * mtf.tanh(x)
    elif activation_fn == "seagull":  # https://arxiv.org/abs/2011.11713
        return lambda x: mtf.log(1 + x**2)
    elif activation_fn == "snake":  # https://arxiv.org/abs/2006.08195
        return lambda x: x + mtf.sin(x)**2
    elif activation_fn == "roottanh":  # made up
        return lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x)
    elif activation_fn == "softplusmone":  # made up
        return lambda x: mtf.softplus(x) - 1
    else:
        raise ValueError('unknown activation function "activation_fn" in config')
def get_activation_fn(params):
    activation_fn = params.get("activation_fn", "gelu")

    if activation_fn == "gelu":  # https://arxiv.org/abs/1606.08415
        return mtf.gelu
    elif activation_fn == "relu":
        return mtf.relu
    elif activation_fn == "sigmoid":
        return mtf.sigmoid
    elif activation_fn == "tanh":
        return mtf.tanh
    elif activation_fn == "selu":  # https://arxiv.org/abs/1706.02515
        return mtf.selu
    elif activation_fn == "elu":  # https://arxiv.org/abs/1511.07289
        return mtf.elu
    elif activation_fn == "abs":
        return mtf.abs
    elif activation_fn == "id":
        return lambda x: x
    elif activation_fn == "sin":
        return mtf.sin
    elif activation_fn == "cos":
        return mtf.cos
    elif activation_fn == "sign":
        return mtf.sign
    elif activation_fn == "triangle_relax":
        return lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49
    elif activation_fn == "square_relax":
        return lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7
    elif activation_fn == "spike":
        return lambda x: 1 / (1 + x**2)
    elif activation_fn == "spike2":
        return lambda x: mtf.exp(-x**2)
    elif activation_fn == "tanhshrink":
        return lambda x: x - mtf.tanh(x)
    elif activation_fn == "softsign":
        return lambda x: x / (mtf.abs(x) + 1)
    elif activation_fn == "softmax":
        return lambda x: mtf.softmax(x, x.shape[-1])
    elif activation_fn == "logsoftmax":
        return lambda x: mtf.log_softmax(x, x.shape[-1])
    elif activation_fn == "bipolarsigmoid":
        return lambda x: mtf.sigmoid(x) * 2 - 1
    elif activation_fn == "rrelu":  # https://arxiv.org/abs/1505.00853
        def _rrelu_fn(x):
            negative_scale = random.random()
            return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)
        return _rrelu_fn
    elif activation_fn == "elish":  # https://arxiv.org/abs/1808.00783v1
        def _elish_fn(x):
            cond = mtf.cast(mtf.greater(x, 0), x.dtype)
            exp = mtf.exp(x)
            return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp + 1)
        return _elish_fn
    # swish activations
    elif activation_fn == "swish":  # https://arxiv.org/abs/1710.05941
        return mtf.swish
    # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
    elif activation_fn == "maxsig":
        return lambda x: mtf.maximum(x, mtf.sigmoid(x))
    elif activation_fn == "cosid":
        return lambda x: mtf.cos(x) - x
    elif activation_fn == "minsin":
        return lambda x: mtf.minimum(x, mtf.sin(x))
    elif activation_fn == "maxtanh":
        return lambda x: mtf.maximum(x, mtf.tanh(x))
    elif activation_fn == "softplus":
        return mtf.softplus
    elif activation_fn == "mish":  # https://arxiv.org/abs/1908.08681
        return lambda x: x * mtf.tanh(mtf.softplus(x))
    elif activation_fn == "tanhexp":  # https://arxiv.org/abs/2003.09855
        return lambda x: x * mtf.tanh(mtf.exp(x))
    elif activation_fn == "lisht":  # https://arxiv.org/abs/1901.05894
        return lambda x: x * mtf.tanh(x)
    elif activation_fn == "seagull":  # https://arxiv.org/abs/2011.11713
        return lambda x: mtf.log(1 + x**2)
    elif activation_fn == "snake":  # https://arxiv.org/abs/2006.08195
        return lambda x: x + mtf.sin(x)**2
    elif activation_fn == "roottanh":  # made up
        return lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x)
    elif activation_fn == "softplusmone":  # made up
        return lambda x: mtf.softplus(x) - 1
    else:
        raise ValueError('unknown activation function "activation_fn" in config')
    (1 / _pos_var(x, 1)))),
    'prelu': lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)),
    'parcsinh': lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)),
    'psoftplus': lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0),
    'proottanh': lambda x: (x**_pos_var(x, 2) + _pos_var(x, 1))**(1 / _pos_var(x, 3)) * mtf.tanh(x),
    'maxsig': lambda x: mtf.maximum(x, mtf.sigmoid(x)),
    'cosid': lambda x: mtf.cos(x) - x,
    'minsin': lambda x: mtf.minimum(x, mtf.sin(x)),
    'maxtanh': lambda x: mtf.maximum(x, mtf.tanh(x)),
    'mish': lambda x: x * mtf.tanh(mtf.softplus(x)),
    'tanhexp': lambda x: x * mtf.tanh(mtf.exp(x)),
    'lisht': lambda x: x * mtf.tanh(x),
    'seagull': lambda x: mtf.log(1 + x**2),
    'snake': lambda x: x + mtf.sin(x)**2,
    'roottanh': lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x),
    'softplusmone':