Example No. 1
0
def main(_):
  """Lower two small programs and print their IR for FileCheck.

  The first case hands ``lax.neg`` straight to ``print_ir`` (defined
  elsewhere in this file); the second first wraps a local Python function
  in ``jax.jit``.  The directives expect each printed module to be named
  ``jit_<function name>`` and to contain a public ``@main`` func.
  """
  # CHECK-LABEL: TEST: neg int32[7]
  # CHECK: module @jit_neg
  # CHECK: func public @main
  print_ir(np.empty([7], np.int32))(lax.neg)

  def foo(x):
    return x + 2

  # CHECK-LABEL: TEST: foo int32[7]
  # CHECK: module @jit_foo
  # CHECK: func public @main
  print_ir(np.empty([7], np.int32))(jax.jit(foo))
Example No. 2
0
def main(_):
  """Lower elementwise lax primitives one at a time and print their IR.

  Each ``print_ir(*operands)(fn)`` call below (``print_ir`` is defined
  elsewhere in this file) is preceded by FileCheck directives: a label
  naming the primitive and its operand types, the MHLO/CHLO op (or
  ``xla_fallback_*`` function) expected in the lowered IR, and the tensor
  type that op is expected to carry.  FileCheck matches the directives in
  order against the printed output, so each call must stay paired with its
  directives and in this order.
  """
  # CHECK-LABEL: TEST: abs int32[]
  # CHECK: mhlo.abs
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(0))(lax.abs)

  # CHECK-LABEL: TEST: add float32[] float32[]
  # CHECK: mhlo.add
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.add)

  # CHECK-LABEL: TEST: acos float32[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.acos)

  # CHECK-LABEL: TEST: acosh float32[]
  # CHECK: xla_fallback_acosh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.acosh)

  # CHECK-LABEL: TEST: asin float32[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.asin)

  # CHECK-LABEL: TEST: asinh float32[]
  # CHECK: xla_fallback_asinh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.asinh)

  # CHECK-LABEL: TEST: atan float32[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1))(lax.atan)

  # CHECK-LABEL: TEST: atanh float32[]
  # CHECK: xla_fallback_atanh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.atanh)

  # CHECK-LABEL: TEST: atan2 float64[] float64[]
  # CHECK: mhlo.atan2
  # CHECK-SAME: tensor<f64>
  print_ir(np.float64(1), np.float64(2))(lax.atan2)

  # CHECK-LABEL: TEST: bessel_i0e float32[]
  # CHECK: xla_fallback_bessel_i0e
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.bessel_i0e)

  # CHECK-LABEL: TEST: bessel_i1e float32[]
  # CHECK: xla_fallback_bessel_i1e
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.bessel_i1e)

  # CHECK-LABEL: TEST: betainc float32[] float32[] float32[]
  # CHECK: xla_fallback_regularized_incomplete_beta
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0), np.float32(0))(lax.betainc)

  # CHECK-LABEL: TEST: bitcast_convert_type uint32[7]
  # CHECK: mhlo.bitcast_convert
  # CHECK-SAME: tensor<7xui32>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.uint32))(
      partial(lax.bitcast_convert_type, new_dtype=np.float32))

  # CHECK-LABEL: TEST: bitwise_and int32[] int32[]
  # CHECK: mhlo.and
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_and)

  # CHECK-LABEL: TEST: bitwise_and bool[] bool[]
  # CHECK: mhlo.and
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_and)

  # CHECK-LABEL: TEST: bitwise_or int32[] int32[]
  # CHECK: mhlo.or
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_or)

  # CHECK-LABEL: TEST: bitwise_or bool[] bool[]
  # CHECK: mhlo.or
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_or)

  # CHECK-LABEL: TEST: bitwise_xor int32[] int32[]
  # CHECK: mhlo.xor
  # CHECK-SAME: tensor<i32>
  print_ir(np.int32(1), np.int32(2))(lax.bitwise_xor)

  # CHECK-LABEL: TEST: bitwise_xor bool[] bool[]
  # CHECK: mhlo.xor
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_xor)

  # CHECK-LABEL: TEST: cbrt bfloat16[]
  # CHECK: mhlo.cbrt
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.cbrt)

  # CHECK-LABEL: TEST: clamp bfloat16[] bfloat16[] bfloat16[]
  # CHECK: mhlo.clamp
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0), jnp.bfloat16(0), jnp.bfloat16(0))(lax.clamp)

  # CHECK-LABEL: TEST: ceil float16[7]
  # CHECK: mhlo.ceil
  # CHECK-SAME: tensor<7xf16>
  print_ir(np.empty((7,), np.float16))(lax.ceil)

  # CHECK-LABEL: TEST: convert_element_type float16[7]
  # CHECK: mhlo.convert
  # CHECK-SAME: tensor<7xf16>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.float16))(
      partial(lax.convert_element_type, new_dtype=np.float32))

  # Complex -> real conversion is expected to lower to `mhlo.real` rather
  # than `mhlo.convert`.
  # CHECK-LABEL: TEST: convert_element_type complex64[7]
  # CHECK: mhlo.real
  # CHECK-SAME: tensor<7xcomplex<f32>>
  # CHECK-SAME: tensor<7xf32>
  print_ir(np.empty((7,), np.complex64))(
      partial(lax.convert_element_type, new_dtype=np.float32))

  # Float -> bool conversion is expected to lower to a comparison.
  # CHECK-LABEL: TEST: convert_element_type float32[7]
  # CHECK: mhlo.compare
  # CHECK-SAME: tensor<7xf32>
  # CHECK-SAME: tensor<7xi1>
  print_ir(np.empty((7,), np.float32))(
      partial(lax.convert_element_type, new_dtype=np.bool_))

  # CHECK-LABEL: TEST: clz uint32[]
  # CHECK: mhlo.count_leading_zeros
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.clz)

  # CHECK-LABEL: TEST: conj complex64[]
  # CHECK-DAG: mhlo.real
  # CHECK-DAG: mhlo.imag
  # CHECK-DAG: mhlo.neg
  # CHECK-DAG: mhlo.complex
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(np.complex64(0))(lax.conj)

  # NOTE(review): `mhlo.cos` here and `mhlo.sin` further down look like
  # prefixes of `mhlo.cosine` / `mhlo.sine`, the names the tan case at the
  # end of this function expects. FileCheck matches substrings, so these
  # presumably still match. Consider using the full names for consistency.
  # CHECK-LABEL: TEST: cos float32[]
  # CHECK: mhlo.cos
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.cos)

  # CHECK-LABEL: TEST: cosh float32[]
  # CHECK: xla_fallback_cosh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.cosh)

  # CHECK-LABEL: TEST: digamma float32[]
  # CHECK: chlo.digamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.digamma)

  # CHECK-LABEL: TEST: div float32[] float32[]
  # CHECK: mhlo.div
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.div)

  # CHECK-LABEL: TEST: eq float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.eq)

  # CHECK-LABEL: TEST: eq complex128[] complex128[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
  # CHECK-SAME: tensor<complex<f64>>
  print_ir(np.complex128(1), np.complex128(2))(lax.eq)

  # CHECK-LABEL: TEST: eq int64[] int64[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type SIGNED">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(1), np.int64(2))(lax.eq)

  # CHECK-LABEL: TEST: eq uint16[] uint16[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type UNSIGNED">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
  # CHECK-SAME: tensor<ui16>
  print_ir(np.uint16(1), np.uint16(2))(lax.eq)

  # CHECK-LABEL: TEST: erf float32[]
  # CHECK: xla_fallback_erf
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erf)

  # CHECK-LABEL: TEST: erfc float32[]
  # CHECK: xla_fallback_erfc
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erfc)

  # CHECK-LABEL: TEST: erf_inv float32[]
  # CHECK: xla_fallback_erf_inv
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.erf_inv)

  # CHECK-LABEL: TEST: exp float16[]
  # CHECK: mhlo.exp
  # CHECK-SAME: tensor<f16>
  print_ir(np.float16(0))(lax.exp)

  # CHECK-LABEL: TEST: expm1 bfloat16[]
  # CHECK: mhlo.exponential_minus_one
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.expm1)

  # CHECK-LABEL: TEST: floor bfloat16[2,3]
  # CHECK: mhlo.floor
  # CHECK-SAME: tensor<2x3xbf16>
  print_ir(np.empty((2, 3), jnp.bfloat16))(lax.floor)

  # CHECK-LABEL: TEST: ge float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GE">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.ge)

  # CHECK-LABEL: TEST: gt float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GT">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.gt)

  # CHECK-LABEL: TEST: igamma float32[] float32[]
  # CHECK: xla_fallback_igamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igamma)

  # CHECK-LABEL: TEST: igammac float32[] float32[]
  # CHECK: xla_fallback_igammac
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igammac)

  # CHECK-LABEL: TEST: igamma_grad_a float32[] float32[]
  # CHECK: xla_fallback_igamma_grad_a
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.igamma_grad_a)

  # CHECK-LABEL: TEST: imag complex64[]
  # CHECK: mhlo.imag
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(np.complex64(0))(lax.imag)

  # integer_pow needs a fixed exponent, so it is wrapped in a local function
  # (whose name presumably supplies the label). The power x**3 is expected
  # to lower to multiplies.
  # CHECK-LABEL: TEST: integer_pow float32[]
  # CHECK-DAG: mhlo.mul
  # CHECK-SAME: tensor<f32>
  @print_ir(np.float32(1))
  def integer_pow(x): return lax.integer_pow(x, 3)

  # CHECK-LABEL: TEST: is_finite float64[]
  # CHECK: mhlo.is_finite
  # CHECK-SAME: tensor<f64>
  print_ir(np.float64(0))(lax.is_finite)

  # CHECK-LABEL: TEST: le float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LE">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.le)

  # CHECK-LABEL: TEST: lgamma float32[]
  # CHECK: chlo.lgamma
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.lgamma)

  # CHECK-LABEL: TEST: log float32[]
  # CHECK: mhlo.log
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.log)

  # CHECK-LABEL: TEST: log1p float32[]
  # CHECK: mhlo.log_plus_one
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.log1p)

  # CHECK-LABEL: TEST: lt float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LT">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.lt)

  # CHECK-LABEL: TEST: max float32[] float32[]
  # CHECK: mhlo.max
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.max)

  # CHECK-LABEL: TEST: min float32[] float32[]
  # CHECK: mhlo.min
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.min)

  # CHECK-LABEL: TEST: mul float32[] float32[]
  # CHECK: mhlo.mul
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.mul)

  # CHECK-LABEL: TEST: ne float32[] float32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction NE">
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.ne)

  # CHECK-LABEL: TEST: neg int64[]
  # CHECK: mhlo.negate
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.neg)

  # CHECK-LABEL: TEST: nextafter float32[] float32[]
  # CHECK: chlo.next_after
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.nextafter)

  # CHECK-LABEL: TEST: bitwise_not int64[]
  # CHECK: mhlo.not
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.bitwise_not)

  # CHECK-LABEL: TEST: bitwise_not bool[]
  # CHECK: mhlo.not
  # CHECK-SAME: tensor<i1>
  print_ir(np.bool_(0))(lax.bitwise_not)

  # CHECK-LABEL: TEST: population_count uint32[]
  # CHECK: mhlo.popcnt
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.population_count)

  # CHECK-LABEL: TEST: pow float32[] float32[]
  # CHECK: mhlo.power
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.pow)

  # CHECK-LABEL: TEST: random_gamma_grad float32[] float32[]
  # CHECK: xla_fallback_random_gamma_grad
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0), np.float32(0))(lax.random_gamma_grad)

  # CHECK-LABEL: TEST: real complex128[]
  # CHECK: mhlo.real
  # CHECK-SAME: tensor<complex<f64>>
  print_ir(np.complex128(0))(lax.real)

  # CHECK-LABEL: TEST: reduce_precision bfloat16[]
  # CHECK: mhlo.reduce_precision
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(
      partial(lax.reduce_precision, exponent_bits=2, mantissa_bits=2))

  # CHECK-LABEL: TEST: rem float32[] float32[]
  # CHECK: mhlo.rem
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.rem)

  # CHECK-LABEL: TEST: round float64[7,1]
  # CHECK: mhlo.round
  # CHECK-SAME: tensor<7x1xf64>
  print_ir(np.empty((7,1), np.float64))(
      partial(lax.round, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO))

  # NOTE(review): this operand is built with `jnp.complex64`, whereas the
  # other complex cases use `np.complex64`/`np.complex128`. It presumably
  # lowers identically; confirm and unify.
  # CHECK-LABEL: TEST: rsqrt complex64[]
  # CHECK: mhlo.rsqrt
  # CHECK-SAME: tensor<complex<f32>>
  print_ir(jnp.complex64(0))(lax.rsqrt)

  # CHECK-LABEL: TEST: shift_left uint32[] uint32[]
  # CHECK: mhlo.shift_left
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0), np.uint32(0))(lax.shift_left)

  # CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[]
  # CHECK: mhlo.shift_right_arithmetic
  # CHECK-SAME: tensor<ui8>
  print_ir(np.uint8(0), np.uint8(0))(lax.shift_right_arithmetic)

  # CHECK-LABEL: TEST: shift_right_logical uint16[] uint16[]
  # CHECK: mhlo.shift_right_logical
  # CHECK-SAME: tensor<ui16>
  print_ir(np.uint16(0), np.uint16(0))(lax.shift_right_logical)

  # CHECK-LABEL: TEST: sign int64[]
  # CHECK: mhlo.sign
  # CHECK-SAME: tensor<i64>
  print_ir(np.int64(0))(lax.sign)

  # For unsigned operands, sign is expected to lower to a comparison rather
  # than `mhlo.sign`.
  # CHECK-LABEL: TEST: sign uint32[]
  # CHECK: mhlo.compare
  # CHECK-SAME: tensor<ui32>
  print_ir(np.uint32(0))(lax.sign)

  # CHECK-LABEL: TEST: sin float32[]
  # CHECK: mhlo.sin
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.sin)

  # CHECK-LABEL: TEST: sinh float32[]
  # CHECK: xla_fallback_sinh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.sinh)

  # CHECK-LABEL: TEST: sub float32[] float32[]
  # CHECK: mhlo.sub
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(1), np.float32(2))(lax.sub)

  # CHECK-LABEL: TEST: sqrt bfloat16[]
  # CHECK: mhlo.sqrt
  # CHECK-SAME: tensor<bf16>
  print_ir(jnp.bfloat16(0))(lax.sqrt)

  # The operand is float16, but the directives expect f32 sine/cosine, so
  # tan is expected to be computed in f32 (upcast) for this input.
  # CHECK-LABEL: TEST: tan float16[]
  # CHECK: mhlo.sine
  # CHECK-SAME: tensor<f32>
  # CHECK: mhlo.cosine
  # CHECK-SAME: tensor<f32>
  print_ir(np.float16(0))(lax.tan)

  # CHECK-LABEL: TEST: tanh float32[]
  # CHECK: mhlo.tanh
  # CHECK-SAME: tensor<f32>
  print_ir(np.float32(0))(lax.tanh)
Example No. 3
0
def main(_):
  """Lower structural lax ops (concatenate, pad, reshape, sort, ...) and print IR.

  ``cases`` lists ``(operands, fn)`` pairs in the order they are lowered.
  Each pair is passed to ``print_ir`` (defined elsewhere in this file) as
  ``print_ir(*operands)(fn)``.  The FileCheck directives written above each
  pair describe the IR expected for it, so pairs and directives must keep
  the same relative order.
  """
  cases = (
      # CHECK-LABEL: TEST: concatenate bool[2,7] bool[2,5]
      # CHECK: mhlo.concatenate
      # CHECK-SAME: tensor<2x12xi1>
      (([np.empty([2, 7], np.bool_), np.empty([2, 5], np.bool_)],),
       partial(lax.concatenate, dimension=1)),

      # CHECK-LABEL: TEST: broadcast_in_dim bool[2,7]
      # CHECK: mhlo.broadcast_in_dim
      # CHECK-SAME: tensor<3x2x5x7x2xi1>
      ((np.empty([2, 7], np.bool_),),
       partial(lax.broadcast_in_dim, shape=(3, 2, 5, 7, 2),
               broadcast_dimensions=(1, 3))),

      # CHECK-LABEL: TEST: iota
      # CHECK: mhlo.iota
      # CHECK-SAME: tensor<10xf32>
      ((), partial(lax.iota, dtype=np.float32, size=10)),

      # CHECK-LABEL: TEST: pad int32[2,7]
      # CHECK: mhlo.pad
      # CHECK-SAME: tensor<11x52xi32>
      ((np.empty([2, 7], np.int32),),
       partial(lax.pad, padding_value=np.int32(7),
               padding_config=((2, 3, 4), (4, 5, 6)))),

      # CHECK-LABEL: TEST: _reduce_sum int32[2,3,7]
      # CHECK: mhlo.reduce
      # CHECK: mhlo.add
      # CHECK: tensor<3xi32>
      ((np.empty([2, 3, 7], np.int32),),
       partial(lax_internal._reduce_sum, axes=(0, 2))),

      # CHECK-LABEL: TEST: reshape int32[2,3,7]
      # CHECK: mhlo.reshape
      # CHECK-SAME: tensor<42xi32>
      ((np.empty([2, 3, 7], np.int32),),
       partial(lax.reshape, new_sizes=(42,))),

      # CHECK-LABEL: TEST: rev int32[2,7]
      # CHECK: mhlo.rev
      # CHECK-SAME: tensor<2x7xi32>
      ((np.empty([2, 7], np.int32),),
       partial(lax.rev, dimensions=(0, 1))),

      # CHECK-LABEL: TEST: select bool[2,7] int32[2,7] int32[2,7]
      # CHECK: mhlo.select
      # CHECK-SAME: tensor<2x7xi1>
      # CHECK-SAME: tensor<2x7xi32>
      # CHECK-SAME: tensor<2x7xi32>
      ((np.empty([2, 7], np.bool_), np.empty([2, 7], np.int32),
        np.empty([2, 7], np.int32)),
       lax.select),

      # CHECK-LABEL: TEST: sort int32[2,7]
      # CHECK: mhlo.sort
      # CHECK: tensor<2x7xi32>
      ((np.empty([2, 7], np.int32),), lax.sort),

      # CHECK-LABEL: TEST: squeeze int32[2,1,7]
      # CHECK: mhlo.reshape
      # CHECK-SAME: tensor<2x7xi32>
      ((np.empty([2, 1, 7], np.int32),),
       partial(lax.squeeze, dimensions=(1,))),

      # CHECK-LABEL: TEST: top_k int32[2,7]
      # CHECK: chlo.top_k
      # CHECK: tensor<2x7xi32>
      ((np.empty([2, 7], np.int32),), partial(lax.top_k, k=7)),

      # CHECK-LABEL: TEST: transpose int32[2,7]
      # CHECK: mhlo.transpose
      # CHECK-SAME: tensor<7x2xi32>
      ((np.empty([2, 7], np.int32),),
       partial(lax.transpose, permutation=(1, 0))),
  )
  for operands, fn in cases:
    print_ir(*operands)(fn)
Example No. 4
0
def main(_):
    """Lower a handful of lax primitives across many dtypes and shapes.

    ``cases`` holds ``(operands, primitive)`` pairs in the order they are
    lowered.  Each pair is passed to ``print_ir`` (defined elsewhere in this
    file) as ``print_ir(*operands)(primitive)``.  The FileCheck directives
    written above each pair describe the IR expected for it, so the pairs
    must keep the same relative order as their directives.  The cases span
    signed and unsigned integer widths, float16/bfloat16/float32/float64,
    complex types, and zero-sized dimensions.
    """
    cases = (
        # CHECK-LABEL: TEST: bitwise_not bool[7]
        # CHECK: mhlo.not
        # CHECK-SAME: tensor<7xi1>
        ((np.empty([7], np.bool_),), lax.bitwise_not),

        # CHECK-LABEL: TEST: neg int8[]
        # CHECK: mhlo.negate
        # CHECK-SAME: tensor<i8>
        ((np.int8(0),), lax.neg),

        # CHECK-LABEL: TEST: neg int16[0]
        # CHECK: mhlo.negate
        # CHECK-SAME: tensor<0xi16>
        ((np.empty([0], np.int16),), lax.neg),

        # CHECK-LABEL: TEST: neg int32[2,3]
        # CHECK: mhlo.negate
        # CHECK-SAME: tensor<2x3xi32>
        ((np.empty([2, 3], np.int32),), lax.neg),

        # CHECK-LABEL: TEST: neg int64[2,3,4]
        # CHECK: mhlo.negate
        # CHECK-SAME: tensor<2x3x4xi64>
        ((np.empty([2, 3, 4], np.int64),), lax.neg),

        # CHECK-LABEL: TEST: add uint8[4,0,1] uint8[4,0,1]
        # CHECK: mhlo.add
        # CHECK-SAME: tensor<4x0x1xui8>
        ((np.empty([4, 0, 1], np.uint8), np.empty([4, 0, 1], np.uint8)),
         lax.add),

        # CHECK-LABEL: TEST: add uint16[] uint16[]
        # CHECK: mhlo.add
        # CHECK-SAME: tensor<ui16>
        ((np.uint16(0), np.uint16(0)), lax.add),

        # CHECK-LABEL: TEST: add uint32[] uint32[]
        # CHECK: mhlo.add
        # CHECK-SAME: tensor<ui32>
        ((np.uint32(0), np.uint32(0)), lax.add),

        # CHECK-LABEL: TEST: add uint64[] uint64[]
        # CHECK: mhlo.add
        # CHECK-SAME: tensor<ui64>
        ((np.uint64(0), np.uint64(0)), lax.add),

        # CHECK-LABEL: TEST: sin float16[]
        # CHECK: mhlo.sine
        # CHECK-SAME: tensor<f16>
        ((np.float16(0),), lax.sin),

        # CHECK-LABEL: TEST: sin bfloat16[]
        # CHECK: mhlo.sine
        # CHECK-SAME: tensor<bf16>
        ((jnp.bfloat16(0),), lax.sin),

        # CHECK-LABEL: TEST: sin float32[]
        # CHECK: mhlo.sine
        # CHECK-SAME: tensor<f32>
        ((np.float32(0),), lax.sin),

        # CHECK-LABEL: TEST: sin float64[]
        # CHECK: mhlo.sine
        # CHECK-SAME: tensor<f64>
        ((np.float64(0),), lax.sin),

        # CHECK-LABEL: TEST: cos complex64[]
        # CHECK: mhlo.cosine
        # CHECK-SAME: tensor<complex<f32>>
        ((np.complex64(0),), lax.cos),

        # CHECK-LABEL: TEST: cos complex128[]
        # CHECK: mhlo.cosine
        # CHECK-SAME: tensor<complex<f64>>
        ((np.complex128(0),), lax.cos),
    )
    for operands, primitive in cases:
        print_ir(*operands)(primitive)