Ejemplo n.º 1
0
 def testCountNonzero(self, shape, dtype, axis):
   """Compare lnp.count_nonzero with onp.count_nonzero on random inputs.

   Arguments are produced by calling the jtu.rand_some_zero() generator with
   the given shape and dtype. Both implementations receive `axis` as their
   second positional argument. The pair is passed to
   self._CheckAgainstNumpy, and the lnp version alone is passed to
   self._CompileAndCheck; dtypes are checked in both.
   """
   sampler = jtu.rand_some_zero()

   def make_args():
     return [sampler(shape, dtype)]

   def reference(x):
     return onp.count_nonzero(x, axis)

   def candidate(x):
     return lnp.count_nonzero(x, axis)

   self._CheckAgainstNumpy(reference, candidate, make_args, check_dtypes=True)
   self._CompileAndCheck(candidate, make_args, check_dtypes=True)
Ejemplo n.º 2
0
    op_record("floor_divide", 2, default_dtypes, jtu.rand_nonzero(), ["rev"]),
    op_record("isclose", 2, float_dtypes, jtu.rand_small_positive(), []),
    op_record("log1p",
              1,
              numeric_dtypes,
              jtu.rand_positive(), [],
              test_name="log1p_large"),
    op_record("log1p", 1, numeric_dtypes, jtu.rand_small_positive(), []),
    op_record("logaddexp", 2, float_dtypes, jtu.rand_default(), ["rev"]),
    op_record("ravel", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("remainder", 2, default_dtypes, jtu.rand_nonzero(), []),
    op_record("sinh", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("sqrt", 1, default_dtypes, jtu.rand_positive(), ["rev"]),
    op_record("transpose", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("true_divide", 2, default_dtypes, jtu.rand_nonzero(), ["rev"]),
    op_record("where", 3, (onp.float32, onp.int64), jtu.rand_some_zero(), []),
]

# Records for the bitwise operators. Every entry is built with the same
# `int_dtypes + unsigned_dtypes` dtype list, its own jtu.rand_bool()
# generator and its own empty list as the last argument; only the operator
# name and the second positional argument (presumably the operand count)
# differ between entries.
JAX_BITWISE_OP_RECORDS = [
    op_record(op_name, arity, int_dtypes + unsigned_dtypes, jtu.rand_bool(),
              [])
    for op_name, arity in (
        ("bitwise_and", 2),
        ("bitwise_not", 1),
        ("bitwise_or", 2),
        ("bitwise_xor", 2),
    )
]

JAX_REDUCER_RECORDS = [
    op_record("all", 1, bool_dtypes, jtu.rand_default(), []),
Ejemplo n.º 3
0
    op_record("floor_divide", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
    op_record("outer", 2, default_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("isclose", 2, float_dtypes, all_shapes, jtu.rand_small_positive(), []),
    op_record("log2", 1, numeric_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("log10", 1, numeric_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("log1p", 1, numeric_dtypes, all_shapes, jtu.rand_positive(), [],
              test_name="log1p_large"),
    op_record("log1p", 1, numeric_dtypes, all_shapes, jtu.rand_small_positive(), []),
    op_record("logaddexp", 2, float_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("logaddexp2", 2, float_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("ravel", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), []),
    op_record("sqrt", 1, default_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("transpose", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("true_divide", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
    op_record("where", 3, (onp.float32, onp.int64), all_shapes, jtu.rand_some_zero(), []),
]

# Records for the bitwise operators. Every entry is built with the same
# `int_dtypes + unsigned_dtypes` dtype list, `all_shapes`, its own
# jtu.rand_bool() generator and its own empty list as the last argument;
# only the operator name and the second positional argument (presumably the
# operand count) differ between entries.
JAX_BITWISE_OP_RECORDS = [
    op_record(op_name, arity, int_dtypes + unsigned_dtypes, all_shapes,
              jtu.rand_bool(), [])
    for op_name, arity in (
        ("bitwise_and", 2),
        ("bitwise_not", 1),
        ("bitwise_or", 2),
        ("bitwise_xor", 2),
    )
]

JAX_REDUCER_RECORDS = [
    op_record("mean", 1, default_dtypes, nonempty_shapes, jtu.rand_default(), []),
Ejemplo n.º 4
0
    op_record("logaddexp", 2, float_dtypes, all_shapes, jtu.rand_default(),
              ["rev"]),
    op_record("logaddexp2", 2, float_dtypes, all_shapes, jtu.rand_default(),
              ["rev"]),
    op_record("ravel", 1, default_dtypes, all_shapes, jtu.rand_default(),
              ["rev"]),
    op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero(),
              []),
    op_record("sqrt", 1, default_dtypes, all_shapes, jtu.rand_positive(),
              ["rev"]),
    op_record("transpose", 1, default_dtypes, all_shapes, jtu.rand_default(),
              ["rev"]),
    op_record("true_divide", 2, default_dtypes, all_shapes, jtu.rand_nonzero(),
              ["rev"]),
    op_record("where", 3, (onp.float32, onp.int64), all_shapes,
              jtu.rand_some_zero(), []),
]

# Table of bitwise-operator records, generated from (name, second-argument)
# pairs. Each op_record call receives `int_dtypes + unsigned_dtypes`,
# `all_shapes`, a freshly created jtu.rand_bool() generator and a fresh
# empty list; the second argument is presumably the number of operands
# (1 for bitwise_not, 2 for the others).
JAX_BITWISE_OP_RECORDS = [
    op_record(bitwise_name, operand_count, int_dtypes + unsigned_dtypes,
              all_shapes, jtu.rand_bool(), [])
    for bitwise_name, operand_count in (
        ("bitwise_and", 2),
        ("bitwise_not", 1),
        ("bitwise_or", 2),
        ("bitwise_xor", 2),
    )
]

JAX_REDUCER_RECORDS = [
    op_record("mean", 1, default_dtypes, nonempty_shapes, jtu.rand_default(),