op_record("ceil", 1, float_dtypes, jtu.rand_default(), []), op_record("conj", 1, numeric_dtypes, jtu.rand_default(), ["rev"]), op_record("conjugate", 1, numeric_dtypes, jtu.rand_default(), ["rev"]), op_record("equal", 2, default_dtypes, jtu.rand_some_equal(), []), op_record("exp", 1, numeric_dtypes, jtu.rand_default(), ["rev"]), op_record("floor", 1, float_dtypes, jtu.rand_default(), []), op_record("greater", 2, default_dtypes, jtu.rand_some_equal(), []), op_record("greater_equal", 2, default_dtypes, jtu.rand_some_equal(), []), op_record("less", 2, default_dtypes, jtu.rand_some_equal(), []), op_record("less_equal", 2, default_dtypes, jtu.rand_some_equal(), []), op_record("log", 1, numeric_dtypes, jtu.rand_positive(), ["rev"]), op_record("logical_and", 2, default_dtypes, jtu.rand_bool(), []), op_record("logical_not", 1, default_dtypes, jtu.rand_bool(), []), op_record("logical_or", 2, default_dtypes, jtu.rand_bool(), []), op_record("logical_xor", 2, default_dtypes, jtu.rand_bool(), []), op_record("maximum", 2, default_dtypes, jtu.rand_some_inf(), []), op_record("minimum", 2, default_dtypes, jtu.rand_some_inf(), []), op_record("multiply", 2, default_dtypes, jtu.rand_default(), ["rev"]), op_record("negative", 1, default_dtypes, jtu.rand_default(), ["rev"]), op_record("not_equal", 2, default_dtypes, jtu.rand_some_equal(), ["rev"]), op_record("power", 2, float_dtypes, jtu.rand_positive(), ["rev"]), op_record("subtract", 2, default_dtypes, jtu.rand_default(), ["rev"]), op_record("tanh", 1, numeric_dtypes, jtu.rand_default(), ["rev"]), op_record("sin", 1, default_dtypes, jtu.rand_default(), ["rev"]), op_record("cos", 1, default_dtypes, jtu.rand_default(), ["rev"]), ] JAX_COMPOUND_OP_RECORDS = [ op_record("cosh", 1, default_dtypes, jtu.rand_default(), ["rev"]), op_record("divide", 2, default_dtypes, jtu.rand_nonzero(), ["rev"]), op_record("expm1",
def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None): test_name = test_name or name return OpRecord(name, nargs, dtypes, shapes, rng, diff_modes, test_name) JAX_ONE_TO_ONE_OP_RECORDS = [ op_record("abs", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("add", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default(), []), op_record("conj", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("conjugate", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("exp", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default(), []), op_record("greater", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("greater_equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("isfinite", 1, numeric_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("less", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("less_equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("log", 1, numeric_dtypes, all_shapes, jtu.rand_positive(), ["rev"]), op_record("logical_and", 2, default_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_not", 1, default_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_or", 2, default_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_xor", 2, default_dtypes, all_shapes, jtu.rand_bool(), []), op_record("maximum", 2, default_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("minimum", 2, default_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("multiply", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("negative", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("not_equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]), op_record("power", 2, float_dtypes, all_shapes, jtu.rand_positive(), ["rev"]), op_record("subtract", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("sin", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
test_name = test_name or name return OpRecord(name, nargs, dtypes, shapes, rng, diff_modes, test_name) JAX_ONE_TO_ONE_OP_RECORDS = [ op_record("abs", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("add", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default(), []), op_record("conj", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("conjugate", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("equal", 2, all_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("exp", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("fabs", 1, float_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default(), []), op_record("greater", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("greater_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("isfinite", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("less", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("less_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []), op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]), op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool(), []), op_record("maximum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("minimum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("multiply", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("not_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]), op_record("power", 2, inexact_dtypes, all_shapes, jtu.rand_positive(), ["rev"]), op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),