def clip_by_global_norm(grads, clip_norm):
    """Rescale gradients so their joint L2 norm does not exceed clip_norm.

    Args:
        grads: an iterable of gradient tensors; entries may be None.
        clip_norm: the largest global norm the returned gradients may have.

    Returns:
        A (clipped_grads, global_norm) pair. clipped_grads is a list parallel
        to grads with None entries kept as None; global_norm is the norm
        computed before any rescaling.
    """
    squared_sums = []
    for grad in grads:
        if grad is not None:
            squared_sums.append(mtf.reduce_sum(mtf.square(grad)))
    global_norm = mtf.sqrt(mtf.add_n(squared_sums))

    # Equals 1 while the norm is within bounds, clip_norm / global_norm
    # once it is exceeded.
    scale = clip_norm / mtf.maximum(global_norm, clip_norm)

    clipped_grads = []
    for grad in grads:
        clipped_grads.append(None if grad is None else grad * scale)
    return clipped_grads, global_norm
def get_activation_fn(params):
    """Return the activation function selected by the config.

    Args:
        params: a mapping supporting ``.get``. The "activation_fn" entry names
            the activation; "gelu" is used when the entry is absent.

    Returns:
        A callable taking a tensor and returning the activated tensor.

    Raises:
        ValueError: if the configured name is not one of the known
            activations. The message includes the offending name.
    """
    activation_fn = params.get("activation_fn", "gelu")

    # Activations implemented directly by mtf under the same name. They are
    # resolved with getattr only after the name matched.
    mtf_builtins = (
        "gelu",  # https://arxiv.org/abs/1606.08415
        "relu",
        "sigmoid",
        "tanh",
        "selu",  # https://arxiv.org/abs/1706.02515
        "elu",  # https://arxiv.org/abs/1511.07289
        "swish",  # https://arxiv.org/abs/1710.05941
        "softplus",
    )

    composed = {
        # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
        "maxsig": lambda x: mtf.maximum(x, mtf.sigmoid(x)),
        "cosid": lambda x: mtf.cos(x) - x,
        "minsin": lambda x: mtf.minimum(x, mtf.sin(x)),
        "maxtanh": lambda x: mtf.maximum(x, mtf.tanh(x)),
        # https://arxiv.org/abs/1908.08681
        "mish": lambda x: x * mtf.tanh(mtf.softplus(x)),
        # https://arxiv.org/abs/2003.09855
        "tanhexp": lambda x: x * mtf.tanh(mtf.exp(x)),
        # https://arxiv.org/abs/1901.05894
        "lisht": lambda x: x * mtf.tanh(x),
        # https://arxiv.org/abs/2011.11713
        "seagull": lambda x: mtf.log(1 + x ** 2),
        # https://arxiv.org/abs/2006.08195
        "snake": lambda x: x + mtf.sin(x) ** 2,
    }

    try:
        if activation_fn in mtf_builtins:
            return getattr(mtf, activation_fn)
        return composed[activation_fn]
    except (KeyError, TypeError):
        # TypeError covers unhashable config values, which are just as
        # unknown as a misspelt name.
        raise ValueError(
            f'unknown activation function "{activation_fn}" in config'
        ) from None
def _relative_position_bucket(relative_position,
                              bidirectional=True,
                              num_buckets=32,
                              max_distance=128):
    """Map relative positions onto bucket ids for relative attention.

    A relative position is memory_position - query_position, i.e. the signed
    token distance from the attending position to the attended-to position.
    With bidirectional=False, positive relative positions are invalid.

    Short distances each get their own bucket; longer distances share
    logarithmically growing buckets. Every relative position >= max_distance
    falls into one bucket, and every relative position <= -max_distance falls
    into one bucket, which is intended to help generalisation to sequences
    longer than those seen in training.

    Args:
        relative_position: an int32 Tensor
        bidirectional: a boolean - whether the attention is bidirectional
        num_buckets: an integer
        max_distance: an integer

    Returns:
        a Tensor with the same shape as relative_position, containing int32
        values in the range [0, num_buckets)
    """
    bucket = 0
    distance = -relative_position
    if not bidirectional:
        distance = mtf.maximum(distance, 0)
    else:
        # Half of the buckets are reserved for each sign of the distance.
        num_buckets //= 2
        is_negative = mtf.to_int32(mtf.less(distance, 0))
        bucket += is_negative * num_buckets
        distance = mtf.abs(distance)
    # distance is now non-negative.

    # Distances below max_exact keep one bucket each.
    max_exact = num_buckets // 2
    is_small = mtf.less(distance, max_exact)

    # Remaining buckets are spread logarithmically up to max_distance.
    log_ratio = mtf.log(mtf.to_float(distance) / max_exact)
    log_span = math.log(max_distance / max_exact)
    large_offset = mtf.to_int32(
        log_ratio / log_span * (num_buckets - max_exact))
    val_if_large = max_exact + large_offset
    val_if_large = mtf.minimum(val_if_large, num_buckets - 1)

    bucket += mtf.where(is_small, distance, val_if_large)
    return bucket
def hybrid_attention(q,
                     k,
                     v,
                     context,
                     memory_length_dim,
                     key_dim,
                     value_dim,
                     bias=None,
                     dropout_rate=0.0,
                     dropout_broadcast_dims=None,
                     extra_logit=None):
    """Dot-product attention mixing two normalisations of the logits.

    Two sets of attention weights are computed from the same logits:
      * a softmax over memory_length_dim, and
      * a log-softmax over a "length" dimension followed by a softmax over
        memory_length_dim.
    They are blended with a scalar variable named "doubly_coeff", which is
    initialised to 0.5 and clamped to [0, 1] before use.

    key_dim is a Dimension representing the channels in the queries and keys
    value_dim is a Dimension representing the channels in values
    memory_length_dim is a Dimension representing the different key/value
    pairs.

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims

    Typically, other_query_dims={batch, heads, length}
    Typically, other_memory_dims={batch, heads}

    Args:
        q: a Tensor
        k: a Tensor
        v: a Tensor
        context: context of the attention layer; its mesh, variable_dtype and
            train attributes are read here.
        memory_length_dim: a Dimension
        key_dim: a Dimension
        value_dim: a Dimension
        bias: a Tensor to be added into the attention logits.
        dropout_rate: a float.
        dropout_broadcast_dims: an optional list of mtf.Dimension
        extra_logit: an optional scalar or tensor

    Returns:
        Tensor with shape q.shape - key_dim + value_dim
    """
    logits = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
    if bias is not None:
        logits += bias

    # NOTE(review): assumes the query axis is named "length" and has the same
    # size as memory_length_dim - confirm against callers.
    query_length_dim = mtf.Dimension("length", memory_length_dim.size)

    mix = mtf.get_variable(
        context.mesh,
        "doubly_coeff", [],
        initializer=tf.constant_initializer(0.5),
        dtype=context.variable_dtype)
    mix = mtf.minimum(mix, 1.)
    mix = mtf.maximum(mix, 0.)

    upper_weights = mtf.softmax(
        logits, memory_length_dim, extra_logit=extra_logit)
    lower_log_weights = mtf.log_softmax(
        logits, query_length_dim, extra_logit=extra_logit)
    doubly_weights = mtf.softmax(
        lower_log_weights, memory_length_dim, extra_logit=extra_logit)

    doubly_part = mix * doubly_weights
    upper_part = (1. - mix) * upper_weights
    weights = doubly_part + upper_part

    weights = mtf.dropout(
        weights,
        context.train,
        1.0 - dropout_rate,
        noise_shape=weights.shape - dropout_broadcast_dims)

    outputs_shape = q.shape - key_dim + value_dim
    return mtf.einsum([weights, v], outputs_shape)
def get_timing_signal_1d(self,
                         context,
                         length,
                         channels,
                         min_timescale=1.0,
                         max_timescale=1.0e4,
                         start_index=0):
    """Gets a bunch of sinusoids of different frequencies.

    Each channel of the input Tensor is incremented by a sinusoid of a
    different frequency and phase.

    This allows attention to learn to use absolute and relative positions.
    Timing signals should be added to some precursors of both the query and
    the memory inputs to attention.

    The use of relative position is possible because sin(x+y) and cos(x+y)
    can be expressed in terms of y, sin(x) and cos(x).

    In particular, we use a geometric sequence of timescales starting with
    min_timescale and ending with max_timescale. The number of different
    timescales is equal to channels / 2. For each timescale, we generate the
    two sinusoidal signals sin(timestep/timescale) and
    cos(timestep/timescale). All of these sinusoids are concatenated in the
    channels dimension.

    Note: `self` is accepted but never read in the body.

    Args:
        context: mtf context. Its get_position(), mesh and activation_dtype
            are used here.
        length: a mtf.Dimension, length of timing signal sequence.
        channels: a mtf.Dimension, size of timing embeddings to create. The
            number of different timescales is equal to channels / 2.
        min_timescale: a float
        max_timescale: a float
        start_index: index of first position

    Returns:
        a Tensor of timing signals [1, length, channels]

    Raises:
        NotImplementedError: if channels.size is odd.
    """
    # Positions come from the context, shifted by start_index.
    position = context.get_position() + start_index
    num_timescales = mtf.constant(context.mesh, channels.size // 2)
    # Log-space step between consecutive timescales; the maximum() guards
    # against dividing by zero when there is a single timescale.
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        mtf.maximum(num_timescales - 1, 1))
    channel_dim_name = channels.name
    # One inverse timescale per sin/cos pair, along a dimension that reuses
    # the channel name at half the channel size.
    inv_timescales = (min_timescale * mtf.exp(
        mtf.mtf_range(context.mesh,
                      mtf.Dimension(channel_dim_name, channels.size // 2),
                      context.activation_dtype) * -log_timescale_increment)
    )
    scaled_time = position * inv_timescales
    # Please note that this slightly differs from the published paper.
    # See a discussion here:
    # https://github.com/tensorflow/tensor2tensor/pull/177
    #    concat_dim_name = scaled_time.shape.dimension_names[1]
    concat_dim_name = channels.name
    # Sine and cosine halves are joined along the channel dimension.
    signal = mtf.concat(
        [mtf.sin(scaled_time), mtf.cos(scaled_time)],
        concat_dim_name=concat_dim_name)

    # NOTE(review): this check runs only after the ops above were built;
    # consider whether it should be the first statement.
    if channels.size % 2 != 0:
        raise NotImplementedError("Odd channel size not implemented.")
    # NOTE(review): the docstring says length and channels are
    # mtf.Dimension, yet `.shape` is read from both, and `channels.shape.dim`
    # is spelled differently from `length.shape.dims` - confirm the intended
    # argument types and attribute name.
    new_dims = [mtf.Dimension("expanded", 1)
                ] + length.shape.dims + channels.shape.dim
    signal = mtf.reshape(signal, mtf.Shape(new_dims))
    return signal
def get_activation_fn(params):
    """Return the activation function selected by the config.

    Args:
        params: a mapping supporting ``.get``. The "activation_fn" entry names
            the activation; "gelu" is used when the entry is absent.

    Returns:
        A callable taking a tensor and returning the activated tensor. The
        parametric activations (aria, prelu, parcsinh, psoftplus, proottanh)
        create new scalar variables on the input's mesh each time the
        returned callable is applied.

    Raises:
        ValueError: if the configured name is not one of the known
            activations. The message includes the offending name.
    """
    activation_fn = params.get("activation_fn", "gelu")

    def _arcsinh(x):
        return mtf.log(x + mtf.sqrt(1 + x**2))

    def _var(x, init):
        # Scalar variable on x's mesh, initialised to `init`.
        # NOTE(review): the name contains a random suffix, so it differs
        # between graph builds - confirm this is compatible with checkpoint
        # restore.
        return mtf.get_variable(
            x.mesh,
            f"activation-{random.randint(0, 2 ** 32):x}", [],
            initializer=tf.constant_initializer(init),
            dtype=x.dtype)

    def _pos_var(x, val):
        # softplus keeps the learned part positive, so the result is > val.
        return mtf.softplus(_var(x, 0)) + val

    def _rrelu_fn(x):
        # https://arxiv.org/abs/1505.00853
        # The scale is drawn in Python each time the function is applied.
        negative_scale = random.random()
        return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)

    def _elish_fn(x):
        # https://arxiv.org/abs/1808.00783v1
        # x > 0: x * sigmoid(x); otherwise: (exp(x) - 1) * sigmoid(x).
        cond = mtf.cast(mtf.greater(x, 0), x.dtype)
        sig = mtf.sigmoid(x)
        # Only x <= 0 contributes to the second term, so clamping before exp
        # leaves the value unchanged while avoiding inf * 0 = NaN for large
        # positive inputs.
        neg_exp = mtf.exp(mtf.minimum(x, 0))
        return cond * x * sig + (1 - cond) * (neg_exp - 1) * sig

    # Config name -> name of the mtf function implementing it. Resolved with
    # getattr only after the name matched.
    mtf_builtins = {
        "gelu": "gelu",  # https://arxiv.org/abs/1606.08415
        "relu": "relu",
        "sigmoid": "sigmoid",
        "tanh": "tanh",
        "selu": "selu",  # https://arxiv.org/abs/1706.02515
        "elu": "elu",  # https://arxiv.org/abs/1511.07289
        "abs": "abs",
        "sin": "sin",
        "cos": "cos",
        "sign": "sign",
        "silu": "swish",  # https://arxiv.org/abs/1710.05941
        "softplus": "softplus",
    }

    composed = {
        "lrelu001": lambda x: mtf.leaky_relu(x, alpha=0.01),
        "lrelu020": lambda x: mtf.leaky_relu(x, alpha=0.20),
        "id": lambda x: x,
        "triangle_relax": lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49,
        "square_relax": lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7,
        "spike": lambda x: 1 / (1 + x**2),
        "spike2": lambda x: mtf.exp(-x**2),
        "tanhshrink": lambda x: x - mtf.tanh(x),
        "softsign": lambda x: x / (mtf.abs(x) + 1),
        "softmax": lambda x: mtf.softmax(x, x.shape[-1]),
        "logsoftmax": lambda x: mtf.log_softmax(x, x.shape[-1]),
        "bipolarsigmoid": lambda x: mtf.sigmoid(x) * 2 - 1,
        "rrelu": _rrelu_fn,
        "elish": _elish_fn,
        "arcsinh": _arcsinh,
        # parametric
        # https://arxiv.org/abs/1805.08878
        "aria": lambda x: x * (_var(x, 0) + _var(x, 1) / (
            _pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x)**(
                1 / _pos_var(x, 1)))),
        # https://arxiv.org/abs/1502.01852
        "prelu": lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)),
        "parcsinh": lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)),
        "psoftplus": lambda x: _var(x, 1) * mtf.softplus(
            x * _var(x, 1)) + _var(x, 0),
        "proottanh": lambda x: (x**_pos_var(x, 2) + _pos_var(x, 1))**(
            1 / _pos_var(x, 3)) * mtf.tanh(x),
        # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
        "maxsig": lambda x: mtf.maximum(x, mtf.sigmoid(x)),
        "cosid": lambda x: mtf.cos(x) - x,
        "minsin": lambda x: mtf.minimum(x, mtf.sin(x)),
        "maxtanh": lambda x: mtf.maximum(x, mtf.tanh(x)),
        # https://arxiv.org/abs/1908.08681
        "mish": lambda x: x * mtf.tanh(mtf.softplus(x)),
        # https://arxiv.org/abs/2003.09855
        "tanhexp": lambda x: x * mtf.tanh(mtf.exp(x)),
        # https://arxiv.org/abs/1901.05894
        "lisht": lambda x: x * mtf.tanh(x),
        # https://arxiv.org/abs/2011.11713
        "seagull": lambda x: mtf.log(1 + x**2),
        # https://arxiv.org/abs/2006.08195
        "snake": lambda x: x + mtf.sin(x)**2,
        # made up
        "roottanh": lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x),
        # made up
        "softplusmone": lambda x: mtf.softplus(x) - 1,
    }

    try:
        if activation_fn in mtf_builtins:
            return getattr(mtf, mtf_builtins[activation_fn])
        return composed[activation_fn]
    except (KeyError, TypeError):
        # TypeError covers unhashable config values, which are just as
        # unknown as a misspelt name.
        raise ValueError(
            f'unknown activation function "{activation_fn}" in config'
        ) from None
def get_activation_fn(params):
    """Return the activation function selected by the config.

    Args:
        params: a mapping supporting ``.get``. The "activation_fn" entry names
            the activation; "gelu" is used when the entry is absent.

    Returns:
        A callable taking a tensor and returning the activated tensor.

    Raises:
        ValueError: if the configured name is not one of the known
            activations. The message includes the offending name.
    """
    activation_fn = params.get("activation_fn", "gelu")

    def _rrelu_fn(x):
        # https://arxiv.org/abs/1505.00853
        # The scale is drawn in Python each time the function is applied.
        negative_scale = random.random()
        return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)

    def _elish_fn(x):
        # https://arxiv.org/abs/1808.00783v1
        # x > 0: x * sigmoid(x); otherwise: (exp(x) - 1) * sigmoid(x).
        cond = mtf.cast(mtf.greater(x, 0), x.dtype)
        sig = mtf.sigmoid(x)
        # Only x <= 0 contributes to the second term, so clamping before exp
        # leaves the value unchanged while avoiding inf * 0 = NaN for large
        # positive inputs.
        neg_exp = mtf.exp(mtf.minimum(x, 0))
        return cond * x * sig + (1 - cond) * (neg_exp - 1) * sig

    # Activations implemented directly by mtf under the same name. They are
    # resolved with getattr only after the name matched.
    mtf_builtins = (
        "gelu",  # https://arxiv.org/abs/1606.08415
        "relu",
        "sigmoid",
        "tanh",
        "selu",  # https://arxiv.org/abs/1706.02515
        "elu",  # https://arxiv.org/abs/1511.07289
        "abs",
        "sin",
        "cos",
        "sign",
        "swish",  # https://arxiv.org/abs/1710.05941
        "softplus",
    )

    composed = {
        "id": lambda x: x,
        "triangle_relax": lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49,
        "square_relax": lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7,
        "spike": lambda x: 1 / (1 + x**2),
        "spike2": lambda x: mtf.exp(-x**2),
        "tanhshrink": lambda x: x - mtf.tanh(x),
        "softsign": lambda x: x / (mtf.abs(x) + 1),
        "softmax": lambda x: mtf.softmax(x, x.shape[-1]),
        "logsoftmax": lambda x: mtf.log_softmax(x, x.shape[-1]),
        "bipolarsigmoid": lambda x: mtf.sigmoid(x) * 2 - 1,
        "rrelu": _rrelu_fn,
        "elish": _elish_fn,
        # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
        "maxsig": lambda x: mtf.maximum(x, mtf.sigmoid(x)),
        "cosid": lambda x: mtf.cos(x) - x,
        "minsin": lambda x: mtf.minimum(x, mtf.sin(x)),
        "maxtanh": lambda x: mtf.maximum(x, mtf.tanh(x)),
        # https://arxiv.org/abs/1908.08681
        "mish": lambda x: x * mtf.tanh(mtf.softplus(x)),
        # https://arxiv.org/abs/2003.09855
        "tanhexp": lambda x: x * mtf.tanh(mtf.exp(x)),
        # https://arxiv.org/abs/1901.05894
        "lisht": lambda x: x * mtf.tanh(x),
        # https://arxiv.org/abs/2011.11713
        "seagull": lambda x: mtf.log(1 + x**2),
        # https://arxiv.org/abs/2006.08195
        "snake": lambda x: x + mtf.sin(x)**2,
        # made up
        "roottanh": lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x),
        # made up
        "softplusmone": lambda x: mtf.softplus(x) - 1,
    }

    try:
        if activation_fn in mtf_builtins:
            return getattr(mtf, activation_fn)
        return composed[activation_fn]
    except (KeyError, TypeError):
        # TypeError covers unhashable config values, which are just as
        # unknown as a misspelt name.
        raise ValueError(
            f'unknown activation function "{activation_fn}" in config'
        ) from None
_arcsinh, 'aria': lambda x: x * (_var(x, 0) + _var(x, 1) / (_pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x)** (1 / _pos_var(x, 1)))), 'prelu': lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)), 'parcsinh': lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)), 'psoftplus': lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0), 'proottanh': lambda x: (x**_pos_var(x, 2) + _pos_var(x, 1))**(1 / _pos_var(x, 3)) * mtf.tanh(x), 'maxsig': lambda x: mtf.maximum(x, mtf.sigmoid(x)), 'cosid': lambda x: mtf.cos(x) - x, 'minsin': lambda x: mtf.minimum(x, mtf.sin(x)), 'maxtanh': lambda x: mtf.maximum(x, mtf.tanh(x)), 'mish': lambda x: x * mtf.tanh(mtf.softplus(x)), 'tanhexp': lambda x: x * mtf.tanh(mtf.exp(x)), 'lisht': lambda x: x * mtf.tanh(x), 'seagull': lambda x: mtf.log(1 + x**2), 'snake':