def benchmark_model(mesh):
    """
    Builds a graph that round-trips a random 3D volume through repeated FFTs.

    A field with dimensions [batch, nx, ny, nz] is drawn with
    mtf.random_normal, cast to complex64, and then passed FLAGS.n_ffts times
    through mpm.fft3d followed by mpm.ifft3d. After each round trip the
    maximum absolute difference to the initial field is added to a running
    total; one more such term is added after the final cast to float32.

    Args:
      mesh: the mesh on which the random field is created.

    Returns:
      The accumulated maximum absolute reconstruction error.
    """
    cube = FLAGS.cube_size
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    # Dimensions of the volume in real space ...
    x_dim, y_dim, z_dim = (
        mtf.Dimension(axis, cube) for axis in ("nx", "ny", "nz"))
    # ... and the separately named dimensions used for the transformed volume.
    tx_dim, ty_dim, tz_dim = (
        mtf.Dimension(axis, cube) for axis in ("tnx", "tny", "tnz"))

    # Keep the untouched real-valued field around as the reference.
    reference = mtf.random_normal(mesh, [batch_dim, x_dim, y_dim, z_dim])
    volume = mtf.cast(reference, tf.complex64)

    total_error = 0
    # Several forward/inverse transforms are chained within the same graph.
    for _ in range(FLAGS.n_ffts):
        spectrum = mpm.fft3d(volume, [tx_dim, ty_dim, tz_dim])
        volume = mpm.ifft3d(spectrum * 1, [x_dim, y_dim, z_dim])
        residual = mtf.cast(volume, tf.float32) - reference
        total_error += mtf.reduce_max(mtf.abs(residual))

    # Final comparison on the volume cast back to float32.
    volume = mtf.cast(volume, tf.float32)
    total_error += mtf.reduce_max(mtf.abs(volume - reference))
    return total_error
def benchmark_model(mesh):
    """
    Builds a graph that sends a random 3D volume through one FFT round trip.

    A field with dimensions [batch, nx, ny, nz] is drawn with
    mtf.random_uniform, cast to complex64, transformed with mpm.fft3d,
    brought back with mpm.ifft3d and cast to float32 again.

    Args:
      mesh: the mesh on which the random field is created.

    Returns:
      The maximum absolute difference between the initial field and its
      reconstruction.
    """
    cube = FLAGS.cube_size
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    # Dimensions of the volume in real space ...
    x_dim, y_dim, z_dim = (
        mtf.Dimension(axis, cube) for axis in ("nx", "ny", "nz"))
    # ... and the separately named dimensions used for the transformed volume.
    tx_dim, ty_dim, tz_dim = (
        mtf.Dimension(axis, cube) for axis in ("tnx", "tny", "tnz"))

    original = mtf.random_uniform(mesh, [batch_dim, x_dim, y_dim, z_dim])

    # Forward transform of the complex-valued copy.
    complex_input = mtf.cast(original, tf.complex64)
    spectrum = mpm.fft3d(complex_input, [tx_dim, ty_dim, tz_dim])

    # Inverse transform, then back to a real-valued tensor.
    restored_complex = mpm.ifft3d(spectrum, [x_dim, y_dim, z_dim])
    restored = mtf.cast(restored_complex, tf.float32)

    return mtf.reduce_max(mtf.abs(original - restored))
def _relative_position_bucket(relative_position,
                              bidirectional=True,
                              num_buckets=32,
                              max_distance=128):
    """Map relative positions to bucket ids for relative attention.

    A relative position is memory_position - query_position, i.e. the signed
    token distance from the attending position to the attended-to position.
    With bidirectional=False, positive relative positions are invalid.

    Small absolute distances each get their own bucket, while larger
    distances share progressively wider buckets. Every relative position
    >= max_distance lands in one bucket, and every relative position
    <= -max_distance lands in one bucket, which should let the model
    generalize more gracefully to sequences longer than those seen in
    training.

    Args:
      relative_position: an int32 Tensor
      bidirectional: a boolean - whether the attention is bidirectional
      num_buckets: an integer
      max_distance: an integer

    Returns:
      a Tensor shaped like relative_position, holding int32 values in the
      range [0, num_buckets)
    """
    bucket = 0
    distance = -relative_position
    if bidirectional:
        # Split the buckets in two halves; negative `distance` values are
        # shifted into the upper half before the sign is dropped.
        num_buckets //= 2
        bucket += mtf.to_int32(mtf.less(distance, 0)) * num_buckets
        distance = mtf.abs(distance)
    else:
        distance = mtf.maximum(distance, 0)
    # From here on `distance` is non-negative.

    # Distances below this threshold map one-to-one onto buckets.
    exact_limit = num_buckets // 2
    within_exact = mtf.less(distance, exact_limit)

    # Remaining buckets are spaced logarithmically up to max_distance.
    log_fraction = (mtf.log(mtf.to_float(distance) / exact_limit) /
                    math.log(max_distance / exact_limit))
    coarse_bucket = exact_limit + mtf.to_int32(
        log_fraction * (num_buckets - exact_limit))
    # Clamp so everything at or beyond max_distance shares the last bucket.
    coarse_bucket = mtf.minimum(coarse_bucket, num_buckets - 1)

    bucket += mtf.where(within_exact, distance, coarse_bucket)
    return bucket
def _rrelu_fn(x):
    """Blend x with |x| using a freshly drawn random weight.

    A scale s is taken from random.random() each time this function is
    called, and the result is (s * |x| + x) / (1 + s).
    """
    scale = random.random()
    blended = scale * mtf.abs(x) + x
    return blended / (1 + scale)
def get_activation_fn(params):
    """Return the activation callable selected by ``params["activation_fn"]``.

    Args:
      params: a mapping supporting ``.get``; its optional "activation_fn"
        entry names the activation and defaults to "gelu" when absent.

    Returns:
      A one-argument callable: either an mtf op or a small function built
      from mtf ops. The parametric entries ("aria", "prelu", "parcsinh",
      "psoftplus", "proottanh") create new scalar mtf variables every time
      the returned callable is invoked.

    Raises:
      ValueError: if the configured name matches none of the known
        activations.
    """
    activation_fn = params.get("activation_fn", "gelu")

    def _arcsinh(x):
        # Inverse hyperbolic sine written out as log(x + sqrt(1 + x^2)).
        return mtf.log(x + mtf.sqrt(1 + x**2))

    def _var(x, init):
        # Scalar (shape []) variable on x's mesh, constant-initialised to
        # `init`. The name carries a random hex suffix drawn from `random`.
        # NOTE(review): names presumably differ between graph builds unless
        # `random` is seeded identically - confirm this is intended.
        return mtf.get_variable(x.mesh,
                                f"activation-{random.randint(0, 2 ** 32):x}",
                                [],
                                initializer=tf.constant_initializer(init),
                                dtype=x.dtype)

    def _pos_var(x, val):
        # softplus of a zero-initialised variable, shifted by `val`.
        return mtf.softplus(_var(x, 0)) + val

    if activation_fn == "gelu":  # https://arxiv.org/abs/1606.08415
        return mtf.gelu
    elif activation_fn == "relu":
        return mtf.relu
    elif activation_fn == "sigmoid":
        return mtf.sigmoid
    elif activation_fn == "tanh":
        return mtf.tanh
    elif activation_fn == "selu":  # https://arxiv.org/abs/1706.02515
        return mtf.selu
    elif activation_fn == "elu":  # https://arxiv.org/abs/1511.07289
        return mtf.elu
    elif activation_fn == "lrelu001":
        return lambda x: mtf.leaky_relu(x, alpha=0.01)
    elif activation_fn == "lrelu020":
        return lambda x: mtf.leaky_relu(x, alpha=0.20)
    elif activation_fn == "abs":
        return mtf.abs
    elif activation_fn == "id":
        return lambda x: x
    elif activation_fn == "sin":
        return mtf.sin
    elif activation_fn == "cos":
        return mtf.cos
    elif activation_fn == "sign":
        return mtf.sign
    elif activation_fn == "triangle_relax":
        return lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49
    elif activation_fn == "square_relax":
        return lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7
    elif activation_fn == "spike":
        return lambda x: 1 / (1 + x**2)
    elif activation_fn == "spike2":
        return lambda x: mtf.exp(-x**2)
    elif activation_fn == "tanhshrink":
        return lambda x: x - mtf.tanh(x)
    elif activation_fn == "softsign":
        return lambda x: x / (mtf.abs(x) + 1)
    elif activation_fn == "softmax":
        return lambda x: mtf.softmax(x, x.shape[-1])
    elif activation_fn == "logsoftmax":
        return lambda x: mtf.log_softmax(x, x.shape[-1])
    elif activation_fn == "bipolarsigmoid":
        return lambda x: mtf.sigmoid(x) * 2 - 1
    elif activation_fn == "rrelu":  # https://arxiv.org/abs/1505.00853

        def _rrelu_fn(x):
            # The scale is drawn once per call of this function.
            negative_scale = random.random()
            return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)

        return _rrelu_fn
    elif activation_fn == "elish":  # https://arxiv.org/abs/1808.00783v1

        def _elish_fn(x):
            # cond is 1 where x > 0 and 0 elsewhere, selecting a branch.
            # NOTE(review): the x > 0 branch divides by (1 + exp(x)); the
            # cited paper appears to use x / (1 + exp(-x)) - confirm which
            # form is intended before changing.
            cond = mtf.cast(mtf.greater(x, 0), x.dtype)
            exp = mtf.exp(x)
            return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp + 1)

        return _elish_fn
    elif activation_fn == "silu":  # https://arxiv.org/abs/1710.05941
        return mtf.swish
    elif activation_fn == "arcsinh":
        return _arcsinh

    # parametric
    elif activation_fn == "aria":  # https://arxiv.org/abs/1805.08878
        return lambda x: x * (_var(x, 0) + _var(x, 1) / (_pos_var(x, 0) + _var(
            x, 1) * mtf.exp(_var(x, -1) * x)**(1 / _pos_var(x, 1))))
    elif activation_fn == "prelu":  # https://arxiv.org/abs/1502.01852
        return lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2))
    elif activation_fn == "parcsinh":
        return lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1))
    elif activation_fn == "psoftplus":
        return lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0)
    elif activation_fn == "proottanh":
        return lambda x: (x**_pos_var(x, 2) + _pos_var(x, 1))**(1 / _pos_var(
            x, 3)) * mtf.tanh(x)

    # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
    elif activation_fn == "maxsig":
        return lambda x: mtf.maximum(x, mtf.sigmoid(x))
    elif activation_fn == "cosid":
        return lambda x: mtf.cos(x) - x
    elif activation_fn == "minsin":
        return lambda x: mtf.minimum(x, mtf.sin(x))
    elif activation_fn == "maxtanh":
        return lambda x: mtf.maximum(x, mtf.tanh(x))

    elif activation_fn == "softplus":
        return mtf.softplus
    elif activation_fn == "mish":  # https://arxiv.org/abs/1908.08681
        return lambda x: x * mtf.tanh(mtf.softplus(x))
    elif activation_fn == "tanhexp":  # https://arxiv.org/abs/2003.09855
        return lambda x: x * mtf.tanh(mtf.exp(x))
    elif activation_fn == "lisht":  # https://arxiv.org/abs/1901.05894
        return lambda x: x * mtf.tanh(x)
    elif activation_fn == "seagull":  # https://arxiv.org/abs/2011.11713
        return lambda x: mtf.log(1 + x**2)
    elif activation_fn == "snake":  # https://arxiv.org/abs/2006.08195
        return lambda x: x + mtf.sin(x)**2

    elif activation_fn == "roottanh":  # made up
        return lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x)
    elif activation_fn == "softplusmone":  # made up
        return lambda x: mtf.softplus(x) - 1
    else:
        raise ValueError(
            'unknown activation function "activation_fn" in config')
def get_activation_fn(params):
    """Return the activation callable selected by ``params["activation_fn"]``.

    Args:
      params: a mapping supporting ``.get``; its optional "activation_fn"
        entry names the activation and defaults to "gelu" when absent.

    Returns:
      A one-argument callable: either an mtf op or a small function built
      from mtf ops.

    Raises:
      ValueError: if the configured name matches none of the known
        activations.
    """
    activation_fn = params.get("activation_fn", "gelu")

    if activation_fn == "gelu":  # https://arxiv.org/abs/1606.08415
        return mtf.gelu
    elif activation_fn == "relu":
        return mtf.relu
    elif activation_fn == "sigmoid":
        return mtf.sigmoid
    elif activation_fn == "tanh":
        return mtf.tanh
    elif activation_fn == "selu":  # https://arxiv.org/abs/1706.02515
        return mtf.selu
    elif activation_fn == "elu":  # https://arxiv.org/abs/1511.07289
        return mtf.elu
    elif activation_fn == "abs":
        return mtf.abs
    elif activation_fn == "id":
        return lambda x: x
    elif activation_fn == "sin":
        return mtf.sin
    elif activation_fn == "cos":
        return mtf.cos
    elif activation_fn == "sign":
        return mtf.sign
    elif activation_fn == "triangle_relax":
        return lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49
    elif activation_fn == "square_relax":
        return lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7
    elif activation_fn == "spike":
        return lambda x: 1 / (1 + x**2)
    elif activation_fn == "spike2":
        return lambda x: mtf.exp(-x**2)
    elif activation_fn == "tanhshrink":
        return lambda x: x - mtf.tanh(x)
    elif activation_fn == "softsign":
        return lambda x: x / (mtf.abs(x) + 1)
    elif activation_fn == "softmax":
        return lambda x: mtf.softmax(x, x.shape[-1])
    elif activation_fn == "logsoftmax":
        return lambda x: mtf.log_softmax(x, x.shape[-1])
    elif activation_fn == "bipolarsigmoid":
        return lambda x: mtf.sigmoid(x) * 2 - 1
    elif activation_fn == "rrelu":  # https://arxiv.org/abs/1505.00853

        def _rrelu_fn(x):
            # The scale is drawn once per call of this function.
            negative_scale = random.random()
            return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)

        return _rrelu_fn
    elif activation_fn == "elish":  # https://arxiv.org/abs/1808.00783v1

        def _elish_fn(x):
            # cond is 1 where x > 0 and 0 elsewhere, selecting a branch.
            # NOTE(review): the x > 0 branch divides by (1 + exp(x)); the
            # cited paper appears to use x / (1 + exp(-x)) - confirm which
            # form is intended before changing.
            cond = mtf.cast(mtf.greater(x, 0), x.dtype)
            exp = mtf.exp(x)
            return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp + 1)

        return _elish_fn

    # swish activations
    elif activation_fn == "swish":  # https://arxiv.org/abs/1710.05941
        return mtf.swish

    # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
    elif activation_fn == "maxsig":
        return lambda x: mtf.maximum(x, mtf.sigmoid(x))
    elif activation_fn == "cosid":
        return lambda x: mtf.cos(x) - x
    elif activation_fn == "minsin":
        return lambda x: mtf.minimum(x, mtf.sin(x))
    elif activation_fn == "maxtanh":
        return lambda x: mtf.maximum(x, mtf.tanh(x))

    elif activation_fn == "softplus":
        return mtf.softplus
    elif activation_fn == "mish":  # https://arxiv.org/abs/1908.08681
        return lambda x: x * mtf.tanh(mtf.softplus(x))
    elif activation_fn == "tanhexp":  # https://arxiv.org/abs/2003.09855
        return lambda x: x * mtf.tanh(mtf.exp(x))
    elif activation_fn == "lisht":  # https://arxiv.org/abs/1901.05894
        return lambda x: x * mtf.tanh(x)
    elif activation_fn == "seagull":  # https://arxiv.org/abs/2011.11713
        return lambda x: mtf.log(1 + x**2)
    elif activation_fn == "snake":  # https://arxiv.org/abs/2006.08195
        return lambda x: x + mtf.sin(x)**2

    elif activation_fn == "roottanh":  # made up
        return lambda x: (x**2 + 1)**(1 / 3) * mtf.tanh(x)
    elif activation_fn == "softplusmone":  # made up
        return lambda x: mtf.softplus(x) - 1
    else:
        raise ValueError(
            'unknown activation function "activation_fn" in config')
'id': lambda x: x, 'triangle_relax': lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(5 * x) / 25 - mtf.sin( 7 * x) / 49, 'square_relax': lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(5 * x) / 5 - mtf.cos( 7 * x) / 7, 'spike': lambda x: 1 / (1 + x**2), 'spike2': lambda x: mtf.exp(-x**2), 'tanhshrink': lambda x: x - tanh(x), 'softsign': lambda x: x / (mtf.abs(x) + 1), 'softmax': lambda x: mtf.softmax(x, x.shape[-1]), 'logsoftmax': lambda x: mtf.log_softmax(x, x.shape[-1]), 'bipolarsigmoid': lambda x: mtf.sigmoid(x) * 2 - 1, 'rrelu': _rrelu, 'elish': _elish, 'arcsinh': _arcsinh, 'aria': lambda x: x * (_var(x, 0) + _var(x, 1) / (_pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x)**