def hardware_bernoulli(rng_key, p=np.float32(0.5), shape=None):
  """Faster RNG."""
  y = 1.0
  x = 0.0
  if FLAGS.use_bfloat16_activation:
    y = jnp.bfloat16(y)
    x = jnp.bfloat16(0.0)
    p = jnp.bfloat16(p)
  y = lax.tie_in(rng_key, y)
  m = lax.rng_uniform(x, y, shape)
  if FLAGS.use_bfloat16_activation:
    assert m.dtype == jnp.bfloat16
  return m < p
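# A minimal usage sketch, not part of the original module (assumptions: jax, numpy as
# np and jax.numpy as jnp are imported, FLAGS.use_bfloat16_activation is defined, and
# lax.rng_uniform / lax.tie_in are available in the JAX version in use). It draws a
# boolean mask where each entry is True with probability p.
example_key = jax.random.PRNGKey(0)
example_mask = hardware_bernoulli(example_key, p=np.float32(0.7), shape=(128, 128))
# example_mask is a bool array of shape (128, 128); roughly 70% of entries are True.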
def test_weak_types(self):
  mul = jax.jit(jnp.multiply)
  # The value `2` here should be weakly typed, and should not lead to
  # promotion.
  tf_fn = jax2tf.convert(lambda x: mul(x, 2.))
  self.assertAllClose(tf_fn(tf.constant(1.375, tf.bfloat16)).numpy(),
                      jnp.bfloat16(2.750))
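# Hedged standalone illustration of the weak-typing behavior exercised above (assumes
# jax.numpy is imported as jnp): a Python scalar such as `2.` is weakly typed, so it
# adopts the bfloat16 dtype of the array operand instead of promoting to float32.
weak_x = jnp.ones((3,), jnp.bfloat16)
assert (weak_x * 2.).dtype == jnp.bfloat16  # no promotion from the weak Python float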
def test_bfloat16_constant(self):
  def jax_fn_scalar(x):
    x = x.astype(jnp.bfloat16)
    x *= 2.
    return x

  def jax_fn_array(x):
    x = x.astype(jnp.bfloat16)
    x *= np.array([1.5, 2.5, 3.5], jnp.bfloat16)
    return x

  tf_fn_scalar = jax2tf.convert(jax_fn_scalar)
  self.assertAllClose(tf_fn_scalar(1.375).numpy(), jnp.bfloat16(2.750))

  tf_fn_array = jax2tf.convert(jax_fn_array)
  self.assertAllClose(tf_fn_array(np.array([3, 4, 5])),
                      np.array([4.5, 10, 17.5], jnp.bfloat16))
def main(_):
  # CHECK-LABEL: TEST: abs int32[]
  # CHECK: mhlo.abs
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(0))(lax.abs)

  # CHECK-LABEL: TEST: add float32[] float32[]
  # CHECK: mhlo.add
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.add)

  # CHECK-LABEL: TEST: acos float32[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.acos)

  # CHECK-LABEL: TEST: acosh float32[]
  # CHECK: xla_fallback_acosh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.acosh)

  # CHECK-LABEL: TEST: asin float32[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.asin)

  # CHECK-LABEL: TEST: asinh float32[]
  # CHECK: xla_fallback_asinh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.asinh)

  # CHECK-LABEL: TEST: atan float32[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.atan)

  # CHECK-LABEL: TEST: atanh float32[]
  # CHECK: xla_fallback_atanh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.atanh)

  # CHECK-LABEL: TEST: atan2 float64[] float64[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f64>
  print_ir(np.float64(1), np.float64(2))(lax.atan2)

  # CHECK-LABEL: TEST: bessel_i0e float32[]
  # CHECK: xla_fallback_bessel_i0e
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.bessel_i0e)

  # CHECK-LABEL: TEST: bessel_i1e float32[]
  # CHECK: xla_fallback_bessel_i1e
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.bessel_i1e)

  # CHECK-LABEL: TEST: betainc float32[] float32[] float32[]
  # CHECK: xla_fallback_regularized_incomplete_beta
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0), np.float32(0))(lax.betainc)

  # CHECK-LABEL: TEST: bitcast_convert_type uint32[7]
  # CHECK: mhlo.bitcast_convert
  # CHECK-SAME: tensor<7xui32>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.uint32))(
      partial(lax.bitcast_convert_type, new_dtype=np.float32))

  # CHECK-LABEL: TEST: bitwise_and int32[] int32[]
  # CHECK: mhlo.and
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_and)

  # CHECK-LABEL: TEST: bitwise_and bool[] bool[]
  # CHECK: mhlo.and
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_and)

  # CHECK-LABEL: TEST: bitwise_or int32[] int32[]
  # CHECK: mhlo.or
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_or)

  # CHECK-LABEL: TEST: bitwise_or bool[] bool[]
  # CHECK: mhlo.or
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_or)

  # CHECK-LABEL: TEST: bitwise_xor int32[] int32[]
  # CHECK: mhlo.xor
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_xor)

  # CHECK-LABEL: TEST: bitwise_xor bool[] bool[]
  # CHECK: mhlo.xor
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_xor)

  # CHECK-LABEL: TEST: cbrt bfloat16[]
  # CHECK: mhlo.cbrt
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.cbrt)

  # CHECK-LABEL: TEST: clamp bfloat16[] bfloat16[] bfloat16[]
  # CHECK: mhlo.clamp
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0), jnp.bfloat16(0), jnp.bfloat16(0))(lax.clamp)

  # CHECK-LABEL: TEST: ceil float16[7]
  # CHECK: mhlo.ceil
  # CHECK-SAME: tensor<7xf16>
  print_ir(np.empty((7,), np.float16))(lax.ceil)

  # CHECK-LABEL: TEST: convert_element_type float16[7]
  # CHECK: mhlo.convert
  # CHECK-SAME: tensor<7xf16>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.float16))(
      partial(lax.convert_element_type, new_dtype=np.float32))

  # CHECK-LABEL: TEST: convert_element_type complex64[7]
  # CHECK: mhlo.real
  # CHECK-SAME: tensor<7xcomplex<f32>>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.complex64))(
      partial(lax.convert_element_type, new_dtype=np.float32))

  # CHECK-LABEL: TEST: convert_element_type float32[7]
  # CHECK: mhlo.compare
  # CHECK-SAME: tensor<7xf32>
  # CHECK-SAME: tensor<7xi1>
  print_ir(np.empty((7,), np.float32))(
      partial(lax.convert_element_type, new_dtype=np.bool_))

  # CHECK-LABEL: TEST: clz uint32[]
  # CHECK: mhlo.count_leading_zeros
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.clz)

  # CHECK-LABEL: TEST: conj complex64[]
  # CHECK-DAG: mhlo.real
  # CHECK-DAG: mhlo.imag
  # CHECK-DAG: mhlo.neg
  # CHECK-DAG: mhlo.complex
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(np.complex64(0))(lax.conj)

  # CHECK-LABEL: TEST: cos float32[]
  # CHECK: mhlo.cos
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.cos)

  # CHECK-LABEL: TEST: cosh float32[]
  # CHECK: xla_fallback_cosh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.cosh)

  # CHECK-LABEL: TEST: digamma float32[]
  # CHECK: chlo.digamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.digamma)

  # CHECK-LABEL: TEST: div float32[] float32[]
  # CHECK: mhlo.div
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.div)

  # CHECK-LABEL: TEST: eq float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.eq)

  # CHECK-LABEL: TEST: eq complex128[] complex128[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
  # CHECK-SAME: tensor<complex<f64>>
  print_ir(np.complex128(1), np.complex128(2))(lax.eq)

  # CHECK-LABEL: TEST: eq int64[] int64[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type SIGNED">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(1), np.int64(2))(lax.eq)

  # CHECK-LABEL: TEST: eq uint16[] uint16[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type UNSIGNED">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
  # CHECK-SAME: tensor<ui16>
  print_ir(np.uint16(1), np.uint16(2))(lax.eq)

  # CHECK-LABEL: TEST: erf float32[]
  # CHECK: xla_fallback_erf
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erf)

  # CHECK-LABEL: TEST: erfc float32[]
  # CHECK: xla_fallback_erfc
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erfc)

  # CHECK-LABEL: TEST: erf_inv float32[]
  # CHECK: xla_fallback_erf_inv
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erf_inv)

  # CHECK-LABEL: TEST: exp float16[]
  # CHECK: mhlo.exp
  # CHECK-SAME: tensor<f16>
  print_ir(np.float16(0))(lax.exp)

  # CHECK-LABEL: TEST: expm1 bfloat16[]
  # CHECK: mhlo.exponential_minus_one
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.expm1)

  # CHECK-LABEL: TEST: floor bfloat16[2,3]
  # CHECK: mhlo.floor
  # CHECK-SAME: tensor<2x3xbf16>
  print_ir(np.empty((2, 3), jnp.bfloat16))(lax.floor)

  # CHECK-LABEL: TEST: ge float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GE">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.ge)

  # CHECK-LABEL: TEST: gt float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GT">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.gt)

  # CHECK-LABEL: TEST: igamma float32[] float32[]
  # CHECK: xla_fallback_igamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igamma)

  # CHECK-LABEL: TEST: igammac float32[] float32[]
  # CHECK: xla_fallback_igammac
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igammac)

  # CHECK-LABEL: TEST: igamma_grad_a float32[] float32[]
  # CHECK: xla_fallback_igamma_grad_a
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igamma_grad_a)

  # CHECK-LABEL: TEST: imag complex64[]
  # CHECK: mhlo.imag
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(np.complex64(0))(lax.imag)

  # CHECK-LABEL: TEST: integer_pow float32[]
  # CHECK-DAG: mhlo.mul
  # CHECK-SAME: tensor<f32>
  @print_ir(np.float32(1))
  def integer_pow(x): return lax.integer_pow(x, 3)

  # CHECK-LABEL: TEST: is_finite float64[]
  # CHECK: mhlo.is_finite
  # CHECK-SAME: tensor<f64>
  print_ir(np.float64(0))(lax.is_finite)

  # CHECK-LABEL: TEST: le float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LE">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.le)

  # CHECK-LABEL: TEST: lgamma float32[]
  # CHECK: chlo.lgamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.lgamma)

  # CHECK-LABEL: TEST: log float32[]
  # CHECK: mhlo.log
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.log)

  # CHECK-LABEL: TEST: log1p float32[]
  # CHECK: mhlo.log_plus_one
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.log1p)

  # CHECK-LABEL: TEST: lt float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LT">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.lt)

  # CHECK-LABEL: TEST: max float32[] float32[]
  # CHECK: mhlo.max
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.max)

  # CHECK-LABEL: TEST: min float32[] float32[]
  # CHECK: mhlo.min
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.min)

  # CHECK-LABEL: TEST: mul float32[] float32[]
  # CHECK: mhlo.mul
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.mul)

  # CHECK-LABEL: TEST: ne float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction NE">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.ne)

  # CHECK-LABEL: TEST: neg int64[]
  # CHECK: mhlo.negate
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.neg)

  # CHECK-LABEL: TEST: nextafter float32[] float32[]
  # CHECK: chlo.next_after
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.nextafter)

  # CHECK-LABEL: TEST: bitwise_not int64[]
  # CHECK: mhlo.not
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.bitwise_not)

  # CHECK-LABEL: TEST: bitwise_not bool[]
  # CHECK: mhlo.not
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0))(lax.bitwise_not)

  # CHECK-LABEL: TEST: population_count uint32[]
  # CHECK: mhlo.popcnt
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.population_count)

  # CHECK-LABEL: TEST: pow float32[] float32[]
  # CHECK: mhlo.power
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.pow)

  # CHECK-LABEL: TEST: random_gamma_grad float32[] float32[]
  # CHECK: xla_fallback_random_gamma_grad
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.random_gamma_grad)

  # CHECK-LABEL: TEST: real complex128[]
  # CHECK: mhlo.real
  # CHECK-SAME: tensor<complex<f64>>
  print_ir(np.complex128(0))(lax.real)

  # CHECK-LABEL: TEST: reduce_precision bfloat16[]
  # CHECK: mhlo.reduce_precision
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(
      partial(lax.reduce_precision, exponent_bits=2, mantissa_bits=2))

  # CHECK-LABEL: TEST: rem float32[] float32[]
  # CHECK: mhlo.rem
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.rem)

  # CHECK-LABEL: TEST: round float64[7,1]
  # CHECK: mhlo.round
  # CHECK-SAME: tensor<7x1xf64>
  print_ir(np.empty((7, 1), np.float64))(
      partial(lax.round, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO))

  # CHECK-LABEL: TEST: rsqrt complex64[]
  # CHECK: mhlo.rsqrt
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(jnp.complex64(0))(lax.rsqrt)

  # CHECK-LABEL: TEST: shift_left uint32[] uint32[]
  # CHECK: mhlo.shift_left
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0), np.uint32(0))(lax.shift_left)

  # CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[]
  # CHECK: mhlo.shift_right_arithmetic
  # CHECK-SAME: tensor<ui8>
  print_ir(np.uint8(0), np.uint8(0))(lax.shift_right_arithmetic)

  # CHECK-LABEL: TEST: shift_right_logical uint16[] uint16[]
  # CHECK: mhlo.shift_right_logical
  # CHECK-SAME: tensor<ui16>
  print_ir(np.uint16(0), np.uint16(0))(lax.shift_right_logical)

  # CHECK-LABEL: TEST: sign int64[]
  # CHECK: mhlo.sign
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.sign)

  # CHECK-LABEL: TEST: sign uint32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.sign)

  # CHECK-LABEL: TEST: sin float32[]
  # CHECK: mhlo.sin
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.sin)

  # CHECK-LABEL: TEST: sinh float32[]
  # CHECK: xla_fallback_sinh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.sinh)

  # CHECK-LABEL: TEST: sub float32[] float32[]
  # CHECK: mhlo.sub
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.sub)

  # CHECK-LABEL: TEST: sqrt bfloat16[]
  # CHECK: mhlo.sqrt
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.sqrt)

  # CHECK-LABEL: TEST: tan float16[]
  # CHECK: mhlo.sine
  # CHECK-SAME: tensor<f32>
  # CHECK: mhlo.cosine
  # CHECK-SAME: tensor<f32>
  print_ir(np.float16(0))(lax.tan)

  # CHECK-LABEL: TEST: tanh float32[]
  # CHECK: mhlo.tanh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.tanh)
def main(_):
  # CHECK-LABEL: TEST: bitwise_not bool[7]
  # CHECK: mhlo.not
  # CHECK-SAME: tensor<7xi1>
  print_ir(np.empty([7], np.bool_))(lax.bitwise_not)

  # CHECK-LABEL: TEST: neg int8[]
  # CHECK: mhlo.negate
  # CHECK-SAME: tensor<i8>
  print_ir(np.int8(0))(lax.neg)

  # CHECK-LABEL: TEST: neg int16[0]
  # CHECK: mhlo.negate
  # CHECK-SAME: tensor<0xi16>
  print_ir(np.empty([0], np.int16))(lax.neg)

  # CHECK-LABEL: TEST: neg int32[2,3]
  # CHECK: mhlo.negate
  # CHECK-SAME: tensor<2x3xi32>
  print_ir(np.empty([2, 3], np.int32))(lax.neg)

  # CHECK-LABEL: TEST: neg int64[2,3,4]
  # CHECK: mhlo.negate
  # CHECK-SAME: tensor<2x3x4xi64>
  print_ir(np.empty([2, 3, 4], np.int64))(lax.neg)

  # CHECK-LABEL: TEST: add uint8[4,0,1] uint8[4,0,1]
  # CHECK: mhlo.add
  # CHECK-SAME: tensor<4x0x1xui8>
  print_ir(np.empty([4, 0, 1], np.uint8),
           np.empty([4, 0, 1], np.uint8))(lax.add)

  # CHECK-LABEL: TEST: add uint16[] uint16[]
  # CHECK: mhlo.add
  # CHECK-SAME: tensor<ui16>
  print_ir(np.uint16(0), np.uint16(0))(lax.add)

  # CHECK-LABEL: TEST: add uint32[] uint32[]
  # CHECK: mhlo.add
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0), np.uint32(0))(lax.add)

  # CHECK-LABEL: TEST: add uint64[] uint64[]
  # CHECK: mhlo.add
  # CHECK-SAME: tensor<ui64>
  print_ir(np.uint64(0), np.uint64(0))(lax.add)

  # CHECK-LABEL: TEST: sin float16[]
  # CHECK: mhlo.sine
  # CHECK-SAME: tensor<f16>
  print_ir(np.float16(0))(lax.sin)

  # CHECK-LABEL: TEST: sin bfloat16[]
  # CHECK: mhlo.sine
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.sin)

  # CHECK-LABEL: TEST: sin float32[]
  # CHECK: mhlo.sine
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.sin)

  # CHECK-LABEL: TEST: sin float64[]
  # CHECK: mhlo.sine
  # CHECK-SAME: tensor<f64>
  print_ir(np.float64(0))(lax.sin)

  # CHECK-LABEL: TEST: cos complex64[]
  # CHECK: mhlo.cosine
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(np.complex64(0))(lax.cos)

  # CHECK-LABEL: TEST: cos complex128[]
  # CHECK: mhlo.cosine
  # CHECK-SAME: tensor<complex<f64>>
  print_ir(np.complex128(0))(lax.cos)
def temporal_shift_tpu(x: types.TensorLike,
                       num_frames: int,
                       channel_shift_fraction: float = 0.125) -> jnp.ndarray:
  """Performs a temporal shift: https://arxiv.org/abs/1811.08383.

  TPU-optimized version of TSM. A reshape is avoided by keeping the images in
  [T * B, :] layout, so that frames corresponding to the same time step across
  videos are contiguous in memory. Thanks to cr/288510308, which allows fusing
  pad->slice into convolution, we reformulate the slice-pad into a pad followed
  by a slice. Finally, to avoid a concatenate that would prevent some fusions,
  we simply sum masked versions of the features.

  Args:
    x: Input expected to be [T * B, H, W, C] (where the batch has been reshaped
      from a time-major version of the input).
    num_frames: number of frames T per video.
    channel_shift_fraction: fraction of the channels to shift forward and
      backward.

  Returns:
    The temporally shifted version of x.
  """
  # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels.
  # Input is (T * B, H, W, C).
  original_shape = list(x.shape)

  batch_size = int(original_shape[0] / num_frames)
  n_channels = int(original_shape[-1])
  n_shift = int(n_channels * channel_shift_fraction)

  # Cast to bfloat16.
  x = x.astype(jnp.bfloat16)

  # For the following, assume that x has 3 channels [x1, x2, x3] and n_shift=1.
  # Shift backward: we first pad by zeros [x1, x2, x3, 0, 0].
  orig_shp = list(x.shape)

  shifted_backward_padding = ((0, batch_size, 0), (0, 0, 0), (0, 0, 0),
                              (0, n_channels - n_shift, 0))
  x_backward_padding = jax.lax.pad(x,
                                   padding_value=jnp.bfloat16(0.),
                                   padding_config=shifted_backward_padding)
  # The following slice gets to [x3^+1, 0, 0] (where +1 means from the future).
  shifted_backward = jax.lax.slice(x_backward_padding,
                                   (batch_size, 0, 0, n_channels - n_shift),
                                   (orig_shp[0] + batch_size, orig_shp[1],
                                    orig_shp[2], 2 * n_channels - n_shift))
  # Shift forward: we first pad by zeros [0, 0, x1, x2, x3].
  shifted_forward_padding = ((batch_size, 0, 0), (0, 0, 0), (0, 0, 0),
                             (n_channels - n_shift, 0, 0))
  x_forward_padding = jax.lax.pad(x,
                                  padding_value=jnp.bfloat16(0.),
                                  padding_config=shifted_forward_padding)
  # The following slice gets to [0, 0, x1^-1] (where -1 means from the past).
  shifted_forward = jax.lax.slice(
      x_forward_padding,
      (0, 0, 0, 0),
      (orig_shp[0], orig_shp[1], orig_shp[2], n_channels))
  # No shift in the middle: this gets [0, x2, 0].
  mask_noshift = (jnp.reshape(
      (jnp.arange(n_channels) >= n_shift) &
      (jnp.arange(n_channels) < n_channels - n_shift),
      (1, 1, 1, -1))).astype(jnp.bfloat16)
  no_shift = mask_noshift * x
  # By summing everything together, we end up with [x3^+1, x2, x1^-1].
  # Note: channels have been reordered, but that doesn't matter for the model.
  shifted_x = shifted_backward + shifted_forward + no_shift

  return shifted_x.astype(jnp.float32)
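# A minimal usage sketch with assumed toy shapes (not from the original module): 2
# videos of 4 frames, 8x8 images, 16 channels, laid out time-major and flattened to
# [T * B, H, W, C] as the docstring requires.
toy_frames, toy_batch = 4, 2
toy_x = jnp.ones((toy_frames * toy_batch, 8, 8, 16), jnp.float32)
toy_out = temporal_shift_tpu(toy_x, num_frames=toy_frames,
                             channel_shift_fraction=0.125)
assert toy_out.shape == toy_x.shape  # the masked-sum shift preserves the input shape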
def self_attention(inputs,
                   variable_dictionary,
                   num_heads: int,
                   qkv_features: int = None,
                   padding_mask: List[bool] = None,
                   dropout_rate: float = 0.,
                   deterministic: bool = False,
                   precision: Precision = None,
                   kernel_init: List[float] = nn.linear.default_kernel_init,
                   bias_init: List[float] = nn.initializers.zeros,
                   dtype: jnp.dtype = jnp.float32,
                   bias: bool = True):
  """Applies multi-head self-attention on the input data.

  Args:
    inputs: input data of shape `[bs, dim1, dim2, ..., dimN, features]`.
    variable_dictionary: Parameter dictionary.
    num_heads: number of attention heads. Features (i.e. inputs.shape[-1])
      should be divisible by the number of heads.
    qkv_features: dimension of the key, query, and value.
    padding_mask: boolean mask specifying tokens that are pad tokens.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation, see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    dtype: datatype for the activations, jnp.bfloat16 or jnp.float32.
    bias: bool, whether pointwise QKVO dense transforms use bias.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features//num_heads]`.
  """
  features = inputs.shape[-1]
  qkv_features = qkv_features or features

  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads
  inputs = inputs.astype(dtype)

  if FLAGS.use_einsum:
    dense_module = Dense3D
  else:
    dense_module = attention.DenseGeneral

  query = dense_module.call(variable_dictionary['query'], inputs, axis=-1,
                            features=(num_heads, head_dim),
                            kernel_init=kernel_init,
                            bias_init=bias_init,
                            bias=bias,
                            precision=precision,
                            dtype=dtype,
                            name='query')
  query = jnp.multiply(query, 1.0 / math.sqrt(float(head_dim)))
  key = dense_module.call(variable_dictionary['key'], inputs, axis=-1,
                          features=(num_heads, head_dim),
                          kernel_init=kernel_init,
                          bias_init=bias_init,
                          bias=bias,
                          precision=precision,
                          dtype=dtype,
                          name='key')
  value = dense_module.call(variable_dictionary['value'], inputs, axis=-1,
                            features=(num_heads, head_dim),
                            kernel_init=kernel_init,
                            bias_init=bias_init,
                            bias=bias,
                            precision=precision,
                            dtype=dtype,
                            name='value')

  assert query.dtype == dtype
  assert key.dtype == dtype
  assert value.dtype == dtype

  # Get raw attention scores from the dot product between key and query.
  # B = batch size (number of sequences)
  # F = `from_tensor` sequence length
  # T = `to_tensor` sequence length
  # N = `num_heads`
  # H = `head_dim` (qkv_features // num_heads)
  attention_scores = jnp.einsum('BTNH,BFNH->BNFT', key, query)
  assert attention_scores.dtype == dtype

  # Create attention masks.
  if padding_mask is not None:
    assert padding_mask.dtype == bool, ('Mask should have bool type.')
    attention_mask = jnp.expand_dims(padding_mask, axis=1)
    adder = (1.0 - attention_mask) * NEG_INFINITY
    attention_scores += adder.astype(dtype)
  assert attention_scores.dtype == dtype

  attention_scores = attention_scores - lax.stop_gradient(
      jnp.max(attention_scores, axis=-1, keepdims=True))
  attention_scores = jnp.exp(attention_scores)
  attention_sum = jnp.sum(attention_scores, axis=-1, keepdims=True)

  keep_prob = 1 - dropout_rate
  if not deterministic:
    keep_mask = jax.random.bernoulli(nn.make_rng(), keep_prob,
                                     attention_scores.shape).astype(dtype)
    assert keep_mask.dtype == dtype
    attention_probs = jnp.multiply(keep_mask, attention_scores)
  else:
    attention_probs = attention_scores
  assert attention_probs.dtype == dtype

  attention_probs = jnp.einsum('BNFT,BTNH->BFNH', attention_probs, value)
  assert attention_probs.dtype == dtype
  attention_probs = attention_probs / jnp.transpose(attention_sum,
                                                    [0, 2, 1, 3])

  # Split the mask and scaling ops of dropout: move the scaling from dropout to
  # here to save some mul ops.
  # TODO(yuemmawang): automate this optimization in xla.
  if not deterministic:
    scale = 1 / keep_prob
    if dtype == jnp.bfloat16:
      scale = jnp.bfloat16(scale)
    attention_probs = jnp.multiply(attention_probs, scale)
  assert attention_probs.dtype == dtype

  return attention_probs
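# Standalone sketch of the split dropout used above, with assumed toy values (demo_*
# names are illustrative only): masking and rescaling are two separate multiplies, so
# the 1/keep_prob factor can be applied late and cast to bfloat16 when needed.
demo_key = jax.random.PRNGKey(0)
demo_keep_prob = 0.9
demo_scores = jnp.ones((2, 4), jnp.bfloat16)
demo_mask = jax.random.bernoulli(demo_key, demo_keep_prob,
                                 demo_scores.shape).astype(demo_scores.dtype)
demo_dropped = jnp.multiply(demo_mask, demo_scores)  # mask first ...
demo_rescaled = jnp.multiply(demo_dropped,
                             jnp.bfloat16(1 / demo_keep_prob))  # ... rescale later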