Example #1
0
# Immutable description of one op under test. The field order matches the
# positional arguments accepted by op_record().
_OP_RECORD_FIELDS = "name nargs dtypes shapes rng diff_modes test_name"
OpRecord = collections.namedtuple("OpRecord", _OP_RECORD_FIELDS)


def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None):
  """Construct an OpRecord for one op under test.

  Args:
    name: op name stored in the record.
    nargs: number of arguments the op takes.
    dtypes: dtypes to test the op with.
    shapes: shapes to test the op with.
    rng: input generator for the op's arguments.
    diff_modes: list of differentiation modes to check (e.g. ["rev"]).
    test_name: optional name for the test; any falsy value (None or "")
      falls back to `name`.

  Returns:
    An OpRecord with the fields above.
  """
  resolved_test_name = test_name if test_name else name
  return OpRecord(
      name=name,
      nargs=nargs,
      dtypes=dtypes,
      shapes=shapes,
      rng=rng,
      diff_modes=diff_modes,
      test_name=resolved_test_name,
  )

# Table of ops to test, one op_record(...) per op. Each entry gives the op
# name, its argument count, the dtypes and shapes to test, an input generator
# (rng), and the differentiation modes to check. Entries with ["rev"]
# presumably enable reverse-mode gradient checks (TODO confirm against the
# test body). Comparison, logical, rounding and min/max ops list no modes.
# NOTE(review): this list is truncated in this excerpt; its closing bracket
# is missing.
JAX_ONE_TO_ONE_OP_RECORDS = [
    op_record("abs", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("add", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("conj", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("conjugate", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("exp", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("greater", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("greater_equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("isfinite", 1, numeric_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("less", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("less_equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("log", 1, numeric_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("logical_and", 2, default_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_not", 1, default_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_or", 2, default_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_xor", 2, default_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("maximum", 2, default_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("minimum", 2, default_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("multiply", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
Example #2
0
def op_record(name, nargs, dtypes, shapes, rng, diff_modes, test_name=None):
    """Construct an OpRecord for one op under test.

    Args:
        name: op name stored in the record.
        nargs: number of arguments the op takes.
        dtypes: dtypes to test the op with.
        shapes: shapes to test the op with.
        rng: input generator for the op's arguments.
        diff_modes: list of differentiation modes to check (e.g. ["rev"]).
        test_name: optional name for the test; any falsy value (None or "")
            is replaced by `name`.

    Returns:
        An OpRecord with the fields above.
    """
    if not test_name:
        # No explicit test name supplied: reuse the op name.
        test_name = name
    return OpRecord(name, nargs, dtypes, shapes, rng, diff_modes, test_name)


JAX_ONE_TO_ONE_OP_RECORDS = [
    op_record("abs", 1, default_dtypes, all_shapes, jtu.rand_default(),
              ["rev"]),
    op_record("add", 2, default_dtypes, all_shapes, jtu.rand_default(),
              ["rev"]),
    op_record("ceil", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("conj", 1, numeric_dtypes, all_shapes, jtu.rand_default(),
              ["rev"]),
    op_record("conjugate", 1, numeric_dtypes, all_shapes, jtu.rand_default(),
              ["rev"]),
    op_record("equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(),
              []),
    op_record("exp", 1, numeric_dtypes, all_shapes, jtu.rand_default(),
              ["rev"]),
    op_record("floor", 1, float_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("greater", 2, default_dtypes, all_shapes, jtu.rand_some_equal(),
              []),
    op_record("greater_equal", 2, default_dtypes, all_shapes,
              jtu.rand_some_equal(), []),
    op_record("less", 2, default_dtypes, all_shapes, jtu.rand_some_equal(),
              []),
    op_record("less_equal", 2, default_dtypes, all_shapes,
              jtu.rand_some_equal(), []),
    op_record("log", 1, numeric_dtypes, all_shapes, jtu.rand_positive(),
              ["rev"]),
    op_record("logical_and", 2, default_dtypes, all_shapes, jtu.rand_bool(),