Example #1
def hardware_bernoulli(rng_key, p=np.float32(0.5), shape=None):
  """Faster RNG."""
  y = 1.0
  x = 0.0
  if FLAGS.use_bfloat16_activation:
    y = jnp.bfloat16(y)
    x = jnp.bfloat16(0.0)
    p = jnp.bfloat16(p)
  y = lax.tie_in(rng_key, y)
  m = lax.rng_uniform(x, y, shape)
  if FLAGS.use_bfloat16_activation:
    assert m.dtype == jnp.bfloat16
  return m < p
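
For context, the snippet below (not part of the original code) sketches the portable alternative this helper is optimized against: drawing the same Boolean mask with jax.random.bernoulli.

import jax

# Portable equivalent: draw a Boolean Bernoulli mask with jax.random.
# hardware_bernoulli above trades this for lax.rng_uniform to reduce RNG
# overhead (e.g. on TPU), optionally computing the threshold in bfloat16.
key = jax.random.PRNGKey(0)
mask = jax.random.bernoulli(key, p=0.5, shape=(128, 128))  # dtype bool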
Example #2
def test_weak_types(self):
  mul = jax.jit(jnp.multiply)
  # The value `2` here should be weakly typed, and should not lead to
  # promotion.
  tf_fn = jax2tf.convert(lambda x: mul(x, 2.))
  self.assertAllClose(tf_fn(tf.constant(1.375, tf.bfloat16)).numpy(),
                      jnp.bfloat16(2.750))
Example #3
    def test_bfloat16_constant(self):
        def jax_fn_scalar(x):
            x = x.astype(jnp.bfloat16)
            x *= 2.
            return x

        def jax_fn_array(x):
            x = x.astype(jnp.bfloat16)
            x *= np.array([1.5, 2.5, 3.5], jnp.bfloat16)
            return x

        tf_fn_scalar = jax2tf.convert(jax_fn_scalar)
        self.assertAllClose(tf_fn_scalar(1.375).numpy(), jnp.bfloat16(2.750))

        tf_fn_array = jax2tf.convert(jax_fn_array)
        self.assertAllClose(tf_fn_array(np.array([3, 4, 5])),
                            np.array([4.5, 10, 17.5], jnp.bfloat16))
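
For reference, a small illustration (not from the original test) of how bfloat16 constants like the ones above can be built; both forms round to the nearest representable bfloat16 value.

import numpy as np
import jax.numpy as jnp

# bfloat16 constants: as a NumPy array with the bfloat16 dtype, or as a scalar.
a = np.array([1.5, 2.5, 3.5], jnp.bfloat16)
b = jnp.bfloat16(2.750)
print(a.dtype, b.dtype)  # bfloat16 bfloat16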
Example #4
def main(_):
  # CHECK-LABEL: TEST: abs int32[]
  # CHECK: mhlo.abs
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(0))(lax.abs)

  # CHECK-LABEL: TEST: add float32[] float32[]
  # CHECK: mhlo.add
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.add)

  # CHECK-LABEL: TEST: acos float32[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.acos)

  # CHECK-LABEL: TEST: acosh float32[]
  # CHECK: xla_fallback_acosh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.acosh)

  # CHECK-LABEL: TEST: asin float32[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.asin)

  # CHECK-LABEL: TEST: asinh float32[]
  # CHECK: xla_fallback_asinh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.asinh)

  # CHECK-LABEL: TEST: atan float32[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.atan)

  # CHECK-LABEL: TEST: atanh float32[]
  # CHECK: xla_fallback_atanh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.atanh)

  # CHECK-LABEL: TEST: atan2 float64[] float64[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f64>
  print_ir(np.float64(1), np.float64(2))(lax.atan2)

  # CHECK-LABEL: TEST: bessel_i0e float32[]
  # CHECK: xla_fallback_bessel_i0e
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.bessel_i0e)

  # CHECK-LABEL: TEST: bessel_i1e float32[]
  # CHECK: xla_fallback_bessel_i1e
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.bessel_i1e)

  # CHECK-LABEL: TEST: betainc float32[] float32[] float32[]
  # CHECK: xla_fallback_regularized_incomplete_beta
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0), np.float32(0))(lax.betainc)

  # CHECK-LABEL: TEST: bitcast_convert_type uint32[7]
  # CHECK: mhlo.bitcast_convert
  # CHECK-SAME: tensor<7xui32>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.uint32))(
      partial(lax.bitcast_convert_type, new_dtype=np.float32))

  # CHECK-LABEL: TEST: bitwise_and int32[] int32[]
  # CHECK: mhlo.and
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_and)

  # CHECK-LABEL: TEST: bitwise_and bool[] bool[]
  # CHECK: mhlo.and
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_and)

  # CHECK-LABEL: TEST: bitwise_or int32[] int32[]
  # CHECK: mhlo.or
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_or)

  # CHECK-LABEL: TEST: bitwise_or bool[] bool[]
  # CHECK: mhlo.or
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_or)

  # CHECK-LABEL: TEST: bitwise_xor int32[] int32[]
  # CHECK: mhlo.xor
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_xor)

  # CHECK-LABEL: TEST: bitwise_xor bool[] bool[]
  # CHECK: mhlo.xor
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_xor)

  # CHECK-LABEL: TEST: cbrt bfloat16[]
  # CHECK: mhlo.cbrt
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.cbrt)

  # CHECK-LABEL: TEST: clamp bfloat16[] bfloat16[] bfloat16[]
  # CHECK: mhlo.clamp
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0), jnp.bfloat16(0), jnp.bfloat16(0))(lax.clamp)

  # CHECK-LABEL: TEST: ceil float16[7]
  # CHECK: mhlo.ceil
  # CHECK-SAME: tensor<7xf16>
  print_ir(np.empty((7,), np.float16))(lax.ceil)

  # CHECK-LABEL: TEST: convert_element_type float16[7]
  # CHECK: mhlo.convert
  # CHECK-SAME: tensor<7xf16>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.float16))(
      partial(lax.convert_element_type, new_dtype=np.float32))

  # CHECK-LABEL: TEST: convert_element_type complex64[7]
  # CHECK: mhlo.real
  # CHECK-SAME: tensor<7xcomplex<f32>>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.complex64))(
      partial(lax.convert_element_type, new_dtype=np.float32))

  # CHECK-LABEL: TEST: convert_element_type float32[7]
  # CHECK: mhlo.compare
  # CHECK-SAME: tensor<7xf32>
  # CHECK-SAME: tensor<7xi1>
  print_ir(np.empty((7,), np.float32))(
      partial(lax.convert_element_type, new_dtype=np.bool_))

  # CHECK-LABEL: TEST: clz uint32[]
  # CHECK: mhlo.count_leading_zeros
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.clz)

  # CHECK-LABEL: TEST: conj complex64[]
  # CHECK-DAG: mhlo.real
  # CHECK-DAG: mhlo.imag
  # CHECK-DAG: mhlo.neg
  # CHECK-DAG: mhlo.complex
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(np.complex64(0))(lax.conj)

  # CHECK-LABEL: TEST: cos float32[]
  # CHECK: mhlo.cos
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.cos)

  # CHECK-LABEL: TEST: cosh float32[]
  # CHECK: xla_fallback_cosh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.cosh)

  # CHECK-LABEL: TEST: digamma float32[]
  # CHECK: chlo.digamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.digamma)

  # CHECK-LABEL: TEST: div float32[] float32[]
  # CHECK: mhlo.div
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.div)

  # CHECK-LABEL: TEST: eq float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.eq)

  # CHECK-LABEL: TEST: eq complex128[] complex128[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
  # CHECK-SAME: tensor<complex<f64>>
  print_ir(np.complex128(1), np.complex128(2))(lax.eq)

  # CHECK-LABEL: TEST: eq int64[] int64[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type SIGNED">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(1), np.int64(2))(lax.eq)

  # CHECK-LABEL: TEST: eq uint16[] uint16[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type UNSIGNED">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
  # CHECK-SAME: tensor<ui16>
  print_ir(np.uint16(1), np.uint16(2))(lax.eq)

  # CHECK-LABEL: TEST: erf float32[]
  # CHECK: xla_fallback_erf
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erf)

  # CHECK-LABEL: TEST: erfc float32[]
  # CHECK: xla_fallback_erfc
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erfc)

  # CHECK-LABEL: TEST: erf_inv float32[]
  # CHECK: xla_fallback_erf_inv
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erf_inv)

  # CHECK-LABEL: TEST: exp float16[]
  # CHECK: mhlo.exp
  # CHECK-SAME: tensor<f16>
  print_ir(np.float16(0))(lax.exp)

  # CHECK-LABEL: TEST: expm1 bfloat16[]
  # CHECK: mhlo.exponential_minus_one
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.expm1)

  # CHECK-LABEL: TEST: floor bfloat16[2,3]
  # CHECK: mhlo.floor
  # CHECK-SAME: tensor<2x3xbf16>
  print_ir(np.empty((2, 3), jnp.bfloat16))(lax.floor)

  # CHECK-LABEL: TEST: ge float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GE">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.ge)

  # CHECK-LABEL: TEST: gt float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GT">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.gt)

  # CHECK-LABEL: TEST: igamma float32[] float32[]
  # CHECK: xla_fallback_igamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igamma)

  # CHECK-LABEL: TEST: igammac float32[] float32[]
  # CHECK: xla_fallback_igammac
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igammac)

  # CHECK-LABEL: TEST: igamma_grad_a float32[] float32[]
  # CHECK: xla_fallback_igamma_grad_a
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igamma_grad_a)

  # CHECK-LABEL: TEST: imag complex64[]
  # CHECK: mhlo.imag
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(np.complex64(0))(lax.imag)

  # CHECK-LABEL: TEST: integer_pow float32[]
  # CHECK-DAG: mhlo.mul
  # CHECK-SAME: tensor<f32>
  @print_ir(np.float32(1))
  def integer_pow(x): return lax.integer_pow(x, 3)

  # CHECK-LABEL: TEST: is_finite float64[]
  # CHECK: mhlo.is_finite
  # CHECK-SAME: tensor<f64>
  print_ir(np.float64(0))(lax.is_finite)

  # CHECK-LABEL: TEST: le float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LE">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.le)

  # CHECK-LABEL: TEST: lgamma float32[]
  # CHECK: chlo.lgamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.lgamma)

  # CHECK-LABEL: TEST: log float32[]
  # CHECK: mhlo.log
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.log)

  # CHECK-LABEL: TEST: log1p float32[]
  # CHECK: mhlo.log_plus_one
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.log1p)

  # CHECK-LABEL: TEST: lt float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LT">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.lt)

  # CHECK-LABEL: TEST: max float32[] float32[]
  # CHECK: mhlo.max
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.max)

  # CHECK-LABEL: TEST: min float32[] float32[]
  # CHECK: mhlo.min
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.min)

  # CHECK-LABEL: TEST: mul float32[] float32[]
  # CHECK: mhlo.mul
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.mul)

  # CHECK-LABEL: TEST: ne float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction NE">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.ne)

  # CHECK-LABEL: TEST: neg int64[]
  # CHECK: mhlo.negate
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.neg)

  # CHECK-LABEL: TEST: nextafter float32[] float32[]
  # CHECK: chlo.next_after
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.nextafter)

  # CHECK-LABEL: TEST: bitwise_not int64[]
  # CHECK: mhlo.not
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.bitwise_not)

  # CHECK-LABEL: TEST: bitwise_not bool[]
  # CHECK: mhlo.not
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0))(lax.bitwise_not)

  # CHECK-LABEL: TEST: population_count uint32[]
  # CHECK: mhlo.popcnt
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.population_count)

  # CHECK-LABEL: TEST: pow float32[] float32[]
  # CHECK: mhlo.power
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.pow)

  # CHECK-LABEL: TEST: random_gamma_grad float32[] float32[]
  # CHECK: xla_fallback_random_gamma_grad
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.random_gamma_grad)

  # CHECK-LABEL: TEST: real complex128[]
  # CHECK: mhlo.real
  # CHECK-SAME: tensor<complex<f64>>
  print_ir(np.complex128(0))(lax.real)

  # CHECK-LABEL: TEST: reduce_precision bfloat16[]
  # CHECK: mhlo.reduce_precision
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(
      partial(lax.reduce_precision, exponent_bits=2, mantissa_bits=2))

  # CHECK-LABEL: TEST: rem float32[] float32[]
  # CHECK: mhlo.rem
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.rem)

  # CHECK-LABEL: TEST: round float64[7,1]
  # CHECK: mhlo.round
  # CHECK-SAME: tensor<7x1xf64>
  print_ir(np.empty((7,1), np.float64))(
      partial(lax.round, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO))

  # CHECK-LABEL: TEST: rsqrt complex64[]
  # CHECK: mhlo.rsqrt
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(jnp.complex64(0))(lax.rsqrt)

  # CHECK-LABEL: TEST: shift_left uint32[] uint32[]
  # CHECK: mhlo.shift_left
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0), np.uint32(0))(lax.shift_left)

  # CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[]
  # CHECK: mhlo.shift_right_arithmetic
  # CHECK-SAME: tensor<ui8>
  print_ir(np.uint8(0), np.uint8(0))(lax.shift_right_arithmetic)

  # CHECK-LABEL: TEST: shift_right_logical uint16[] uint16[]
  # CHECK: mhlo.shift_right_logical
  # CHECK-SAME: tensor<ui16>
  print_ir(np.uint16(0), np.uint16(0))(lax.shift_right_logical)

  # CHECK-LABEL: TEST: sign int64[]
  # CHECK: mhlo.sign
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.sign)

  # CHECK-LABEL: TEST: sign uint32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.sign)

  # CHECK-LABEL: TEST: sin float32[]
  # CHECK: mhlo.sin
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.sin)

  # CHECK-LABEL: TEST: sinh float32[]
  # CHECK: xla_fallback_sinh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.sinh)

  # CHECK-LABEL: TEST: sub float32[] float32[]
  # CHECK: mhlo.sub
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.sub)

  # CHECK-LABEL: TEST: sqrt bfloat16[]
  # CHECK: mhlo.sqrt
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.sqrt)

  # CHECK-LABEL: TEST: tan float16[]
  # CHECK: mhlo.sine
  # CHECK-SAME: tensor<f32>
  # CHECK: mhlo.cosine
  # CHECK-SAME: tensor<f32>
  print_ir(np.float16(0))(lax.tan)

  # CHECK-LABEL: TEST: tanh float32[]
  # CHECK: mhlo.tanh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.tanh)
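
The print_ir helper is not shown in this example. Below is a hypothetical sketch of what it might look like, assuming a JAX version whose Lowered.compiler_ir accepts dialect="mhlo"; the banner format and shape formatting are guesses, not the original implementation.

from functools import partial

import jax
import numpy as np

def print_ir(*prototypes):
  """Lowers a callable for the given prototype arguments and prints a
  TEST banner plus the MHLO module for FileCheck to scan (sketch only)."""
  def do_print(f):
    func = f.func if isinstance(f, partial) else f
    name = getattr(func, "__name__", str(func))
    shapes = " ".join(
        f"{np.dtype(p.dtype).name}[{','.join(str(d) for d in np.shape(p))}]"
        for p in prototypes)
    print(f"\nTEST: {name} {shapes}")
    print(jax.jit(f).lower(*prototypes).compiler_ir(dialect="mhlo"))
    return f
  return do_print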
Example #5
def main(_):
    # CHECK-LABEL: TEST: bitwise_not bool[7]
    # CHECK: mhlo.not
    # CHECK-SAME: tensor<7xi1>
    print_ir(np.empty([7], np.bool_))(lax.bitwise_not)

    # CHECK-LABEL: TEST: neg int8[]
    # CHECK: mhlo.negate
    # CHECK-SAME: tensor<i8>
    print_ir(np.int8(0))(lax.neg)

    # CHECK-LABEL: TEST: neg int16[0]
    # CHECK: mhlo.negate
    # CHECK-SAME: tensor<0xi16>
    print_ir(np.empty([0], np.int16))(lax.neg)

    # CHECK-LABEL: TEST: neg int32[2,3]
    # CHECK: mhlo.negate
    # CHECK-SAME: tensor<2x3xi32>
    print_ir(np.empty([2, 3], np.int32))(lax.neg)

    # CHECK-LABEL: TEST: neg int64[2,3,4]
    # CHECK: mhlo.negate
    # CHECK-SAME: tensor<2x3x4xi64>
    print_ir(np.empty([2, 3, 4], np.int64))(lax.neg)

    # CHECK-LABEL: TEST: add uint8[4,0,1] uint8[4,0,1]
    # CHECK: mhlo.add
    # CHECK-SAME: tensor<4x0x1xui8>
    print_ir(np.empty([4, 0, 1], np.uint8), np.empty([4, 0, 1],
                                                     np.uint8))(lax.add)

    # CHECK-LABEL: TEST: add uint16[] uint16[]
    # CHECK: mhlo.add
    # CHECK-SAME: tensor<ui16>
    print_ir(np.uint16(0), np.uint16(0))(lax.add)

    # CHECK-LABEL: TEST: add uint32[] uint32[]
    # CHECK: mhlo.add
    # CHECK-SAME: tensor<ui32>
    print_ir(np.uint32(0), np.uint32(0))(lax.add)

    # CHECK-LABEL: TEST: add uint64[] uint64[]
    # CHECK: mhlo.add
    # CHECK-SAME: tensor<ui64>
    print_ir(np.uint64(0), np.uint64(0))(lax.add)

    # CHECK-LABEL: TEST: sin float16[]
    # CHECK: mhlo.sine
    # CHECK-SAME: tensor<f16>
    print_ir(np.float16(0))(lax.sin)

    # CHECK-LABEL: TEST: sin bfloat16[]
    # CHECK: mhlo.sine
    # CHECK-SAME: tensor<bf16>
    print_ir(jnp.bfloat16(0))(lax.sin)

    # CHECK-LABEL: TEST: sin float32[]
    # CHECK: mhlo.sine
    # CHECK-SAME: tensor<f32>
    print_ir(np.float32(0))(lax.sin)

    # CHECK-LABEL: TEST: sin float64[]
    # CHECK: mhlo.sine
    # CHECK-SAME: tensor<f64>
    print_ir(np.float64(0))(lax.sin)

    # CHECK-LABEL: TEST: cos complex64[]
    # CHECK: mhlo.cosine
    # CHECK-SAME: tensor<complex<f32>>
    print_ir(np.complex64(0))(lax.cos)

    # CHECK-LABEL: TEST: cos complex128[]
    # CHECK: mhlo.cosine
    # CHECK-SAME: tensor<complex<f64>>
    print_ir(np.complex128(0))(lax.cos)
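
Outside of a FileCheck harness, the same lowering can be inspected directly; a minimal sketch (again assuming Lowered.compiler_ir supports the mhlo dialect):

import jax
import numpy as np
from jax import lax

# Print the MHLO for one of the cases above and check the emitted ops by eye;
# the text is expected to contain mhlo.negate and tensor<2x3xi32>.
ir = jax.jit(lax.neg).lower(np.empty([2, 3], np.int32)).compiler_ir(dialect="mhlo")
print(ir)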
Example #6
def temporal_shift_tpu(x: types.TensorLike,
                       num_frames: int,
                       channel_shift_fraction: float = 0.125) -> jnp.ndarray:
    """Performs a temporal shift: https://arxiv.org/abs/1811.08383.

    TPU optimized version of TSM. Reshape is avoided by having the images
    reshaped in [T * B, :] so that frames corresponding to same time frame in
    videos are contiguous in memory. Thanks to cr/288510308 which allows to fuse
    pad->slice into convolution, we reformulate the slice pad into a pad then
    slice. Finally, to avoid concatenate that prevent some fusion from happening
    we simply sum masked version of the features.
  Args:
    x: Input expected to be [T * B, H, W, C] (where the batch has been reshaped
      from a time major version of the input).
    num_frames: number of frames T per video.
    channel_shift_fraction: fraction of the channel to shift forward and
      backward.

  Returns:
      The temporal shifted version of x.
  """
    # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels
    # Input is (T * B, H, W, C)
    original_shape = list(x.shape)

    batch_size = int(original_shape[0] / num_frames)
    n_channels = int(original_shape[-1])
    n_shift = int(n_channels * channel_shift_fraction)

    # Cast to bfloat16.
    x = x.astype(jnp.bfloat16)

    # For the following, assume that x has 3 channels [x1, x2, x3] and n_shift=1.
    # Shift backward, we first pad by zeros [x1, x2, x3, 0, 0].
    orig_shp = list(x.shape)

    shifted_backward_padding = ((0, batch_size, 0), (0, 0, 0), (0, 0, 0),
                                (0, n_channels - n_shift, 0))
    x_backward_padding = jax.lax.pad(x,
                                     padding_value=jnp.bfloat16(0.),
                                     padding_config=shifted_backward_padding)
    # The following shift gets to [x3^+1, 0, 0] (where +1 means from the future).
    shifted_backward = jax.lax.slice(x_backward_padding,
                                     (batch_size, 0, 0, n_channels - n_shift),
                                     (orig_shp[0] + batch_size, orig_shp[1],
                                      orig_shp[2], 2 * n_channels - n_shift))
    # Shift forward, we first pad by zeros [0, 0, x1, x2, x3].
    shifted_forward_padding = ((batch_size, 0, 0), (0, 0, 0), (0, 0, 0),
                               (n_channels - n_shift, 0, 0))
    x_forward_padding = jax.lax.pad(x,
                                    padding_value=jnp.bfloat16(0.),
                                    padding_config=shifted_forward_padding)
    # The following shift gets to [0, 0, x1^-1] (where -1 means from the past).
    shifted_forward = jax.lax.slice(
        x_forward_padding, (0, 0, 0, 0),
        (orig_shp[0], orig_shp[1], orig_shp[2], n_channels))
    # No shift is in the middle, this gets [0, x2, 0].
    mask_noshift = (jnp.reshape(
        (jnp.arange(n_channels) >= n_shift) &
        (jnp.arange(n_channels) < n_channels - n_shift),
        (1, 1, 1, -1))).astype(jnp.bfloat16)
    no_shift = mask_noshift * x
    # By summing everything together, we end up with [x3^+1, x2, x1^-1].
    # Note: channels have been reordered but that doesn't matter for the model.
    shifted_x = shifted_backward + shifted_forward + no_shift

    return shifted_x.astype(jnp.float32)
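
A hypothetical smoke test (dummy shapes, assuming temporal_shift_tpu above is in scope) showing the expected [T * B, H, W, C] layout and that the shift preserves shape while returning float32:

import jax.numpy as jnp

# 2 videos of 4 frames each, 8x8 images, 16 channels, flattened time-major
# to [T * B, H, W, C] as the docstring requires.
T, B, H, W, C = 4, 2, 8, 8, 16
x = jnp.ones((T * B, H, W, C), jnp.float32)
y = temporal_shift_tpu(x, num_frames=T, channel_shift_fraction=0.125)
assert y.shape == (T * B, H, W, C)
assert y.dtype == jnp.float32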
Example #7
def self_attention(inputs,
                   variable_dictionary,
                   num_heads: int,
                   qkv_features: int = None,
                   padding_mask: List[bool] = None,
                   dropout_rate: float = 0.,
                   deterministic: bool = False,
                   precision: Precision = None,
                   kernel_init: List[float] = nn.linear.default_kernel_init,
                   bias_init: List[float] = nn.initializers.zeros,
                   dtype: jnp.dtype = jnp.float32,
                   bias: bool = True):
    """Applies Multi-head self-attention on the input data.

  Args:
    inputs: input data of shape `[bs, dim1, dim2, ..., dimN, features]`.
    variable_dictionary: Parameter dictionary.
    num_heads: number of attention heads. Features (i.e. inputs.shape[-1])
      should be divisible by the number of heads.
    qkv_features: dimension of the key, query, and value.
    padding_mask: boolean specifying tokens that are pad token.
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    dtype: datatype for the activiations, jnp.bfloat16 or jnp.float32
    bias: bool: whether pointwise QKVO dense transforms use bias.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features//num_heads]`.
  """

    features = inputs.shape[-1]
    qkv_features = qkv_features or features

    assert qkv_features % num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // num_heads
    inputs = inputs.astype(dtype)
    if FLAGS.use_einsum:
        dense_module = Dense3D
    else:
        dense_module = attention.DenseGeneral

    query = dense_module.call(variable_dictionary['query'],
                              inputs,
                              axis=-1,
                              features=(num_heads, head_dim),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              precision=precision,
                              dtype=dtype,
                              name='query')
    query = jnp.multiply(query, 1.0 / math.sqrt(float(head_dim)))
    key = dense_module.call(variable_dictionary['key'],
                            inputs,
                            axis=-1,
                            features=(num_heads, head_dim),
                            kernel_init=kernel_init,
                            bias_init=bias_init,
                            bias=bias,
                            precision=precision,
                            dtype=dtype,
                            name='key')
    value = dense_module.call(variable_dictionary['value'],
                              inputs,
                              axis=-1,
                              features=(num_heads, head_dim),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              precision=precision,
                              dtype=dtype,
                              name='value')

    assert query.dtype == dtype
    assert key.dtype == dtype
    assert value.dtype == dtype
    # get raw attention scores from dot product between key and query
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_heads`
    #   H = `head_dim` (qkv_features // num_heads)
    attention_scores = jnp.einsum('BTNH,BFNH->BNFT', key, query)
    assert attention_scores.dtype == dtype
    # create attention masks
    if padding_mask is not None:
        assert padding_mask.dtype == bool, ('Mask should have bool type.')
        attention_mask = jnp.expand_dims(padding_mask, axis=1)
        adder = (1.0 - attention_mask) * NEG_INFINITY
        attention_scores += adder.astype(dtype)
    assert attention_scores.dtype == dtype

    attention_scores = attention_scores - lax.stop_gradient(
        jnp.max(attention_scores, axis=-1, keepdims=True))
    attention_scores = jnp.exp(attention_scores)
    attention_sum = jnp.sum(attention_scores, axis=-1, keepdims=True)

    keep_prob = 1 - dropout_rate
    if not deterministic:
        keep_mask = jax.random.bernoulli(nn.make_rng(), keep_prob,
                                         attention_scores.shape).astype(dtype)
        assert keep_mask.dtype == dtype
        attention_probs = jnp.multiply(keep_mask, attention_scores)
    else:
        attention_probs = attention_scores

    assert attention_probs.dtype == dtype

    attention_probs = jnp.einsum('BNFT,BTNH->BFNH', attention_probs, value)
    assert attention_probs.dtype == dtype
    attention_probs = attention_probs / jnp.transpose(attention_sum,
                                                      [0, 2, 1, 3])

    # split mask and scaling ops in dropout
    # move the scaling from dropout to here to save same mul ops
    # TODO(yuemmawang) automate this optimization in xla
    if not deterministic:
        scale = 1 / keep_prob
        if dtype == jnp.bfloat16:
            scale = jnp.bfloat16(scale)
        attention_probs = jnp.multiply(attention_probs, scale)
    assert attention_probs.dtype == dtype

    return attention_probs
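
To make the tensor contractions above concrete, a small standalone illustration of the two einsum shape signatures (dummy values; the dimension sizes are arbitrary and not from the original code):

import jax.numpy as jnp

# B = 2 sequences, F = T = 5 tokens, N = 4 heads, H = 8 features per head.
B, F, T, N, H = 2, 5, 5, 4, 8
query = jnp.zeros((B, F, N, H), jnp.bfloat16)
key = jnp.zeros((B, T, N, H), jnp.bfloat16)
value = jnp.zeros((B, T, N, H), jnp.bfloat16)
scores = jnp.einsum('BTNH,BFNH->BNFT', key, query)      # (2, 4, 5, 5)
context = jnp.einsum('BNFT,BTNH->BFNH', scores, value)  # (2, 5, 4, 8)
assert scores.shape == (B, N, F, T) and context.shape == (B, F, N, H)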