Ejemplo n.º 1
0
    op_record("tanh", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("arcsin", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arccos", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctan", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctan2", 2, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arcsinh", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arccosh", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctanh", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
]

JAX_COMPOUND_OP_RECORDS = [
    op_record("divide", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
    op_record("exp2", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("expm1", 1, numeric_dtypes, all_shapes, jtu.rand_positive(), [],
              test_name="expm1_large"),
    op_record("expm1", 1, numeric_dtypes, all_shapes, jtu.rand_small_positive(), []),
    op_record("floor_divide", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
    op_record("outer", 2, default_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("isclose", 2, float_dtypes, all_shapes, jtu.rand_small_positive(), []),
    op_record("log2", 1, numeric_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("log10", 1, numeric_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("log1p", 1, numeric_dtypes, all_shapes, jtu.rand_positive(), [],
              test_name="log1p_large"),
    op_record("log1p", 1, numeric_dtypes, all_shapes, jtu.rand_small_positive(), []),
    op_record("logaddexp", 2, float_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("logaddexp2", 2, float_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("ravel", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), []),
    op_record("sqrt", 1, default_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("transpose", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("true_divide", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
Ejemplo n.º 2
0
# All numeric dtypes under test: float_dtypes, then complex_dtypes, then
# int_dtypes (each defined earlier in the file, outside this view; assumed to be
# sequences so that ``+`` concatenates them in that order).
numeric_dtypes = float_dtypes + complex_dtypes + int_dtypes

# Immutable description of a single op test case. The ``test_autodiff`` field
# holds whatever ``op_record`` received as its ``test_grad`` argument.
OpRecord = collections.namedtuple(
    "OpRecord", "name nargs dtypes rng test_autodiff test_name")


def op_record(name, nargs, dtypes, rng, test_grad, test_name=None):
    """Create an ``OpRecord`` for one op.

    ``test_grad`` is stored in the ``test_autodiff`` field. If ``test_name`` is
    falsy (``None`` or ``""``), the op ``name`` is used as the test name.
    """
    if not test_name:
        test_name = name
    return OpRecord(name=name, nargs=nargs, dtypes=dtypes, rng=rng,
                    test_autodiff=test_grad, test_name=test_name)


# Special-function test cases, one OpRecord per function. Positional arguments
# to op_record are (name, nargs, dtypes, rng, test_grad); test_grad is stored
# as OpRecord.test_autodiff.
# NOTE(review): presumably consumed by LaxBackedScipyTests to parameterize its
# tests, with test_grad=False skipping the gradient check -- confirm there.
JAX_SPECIAL_FUNCTION_RECORDS = [
    # TODO: digamma has no JVP implemented.
    op_record("digamma", 1, float_dtypes, jtu.rand_positive(), False),
    op_record("erf", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfc", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("expit", 1, float_dtypes, jtu.rand_small_positive(), True),
    # TODO: gammaln has slightly high error.
    op_record("gammaln", 1, float_dtypes, jtu.rand_positive(), False),
    # NOTE(review): logit also has test_grad=False, but no reason is recorded.
    op_record("logit", 1, float_dtypes, jtu.rand_small_positive(), False),
    op_record("log_ndtr", 1, float_dtypes, jtu.rand_small(), True),
    # NOTE(review): rand_uniform(0., 1.) presumably keeps inputs inside ndtri's
    # (0, 1) domain -- confirm the bounds semantics in jtu.
    op_record("ndtri", 1, float_dtypes, jtu.rand_uniform(0., 1.), True),
    op_record("ndtr", 1, float_dtypes, jtu.rand_default(), True),
]

# Short local alias for itertools.combinations_with_replacement.
from itertools import combinations_with_replacement as CombosWithReplacement


class LaxBackedScipyTests(jtu.JaxTestCase):
Ejemplo n.º 3
0
              ["rev"]),
]

JAX_COMPOUND_OP_RECORDS = [
    op_record("divide", 2, default_dtypes, all_shapes, jtu.rand_nonzero(),
              ["rev"]),
    op_record("exp2", 1, numeric_dtypes, all_shapes, jtu.rand_default(),
              ["rev"]),
    op_record("expm1",
              1,
              numeric_dtypes,
              all_shapes,
              jtu.rand_positive(), [],
              test_name="expm1_large"),
    op_record("expm1", 1, numeric_dtypes, all_shapes,
              jtu.rand_small_positive(), []),
    op_record("floor_divide", 2, default_dtypes, all_shapes,
              jtu.rand_nonzero(), ["rev"]),
    op_record("isclose", 2, float_dtypes, all_shapes,
              jtu.rand_small_positive(), []),
    op_record("log2", 1, numeric_dtypes, all_shapes, jtu.rand_positive(),
              ["rev"]),
    op_record("log10", 1, numeric_dtypes, all_shapes, jtu.rand_positive(),
              ["rev"]),
    op_record("log1p",
              1,
              numeric_dtypes,
              all_shapes,
              jtu.rand_positive(), [],
              test_name="log1p_large"),
    op_record("log1p", 1, numeric_dtypes, all_shapes,
Ejemplo n.º 4
0
    op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("arcsin", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arccos", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctan", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctan2", 2, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arccosh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctanh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
]

JAX_COMPOUND_OP_RECORDS = [
    op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
    op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive(), [],
              test_name="expm1_large"),
    op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive(), []),
    op_record("floor_divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
    op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default(), []),
    op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default(), []),
    op_record("isclose", 2, all_dtypes, all_shapes, jtu.rand_small_positive(), []),
    op_record("log2", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("log10", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_positive(), [],
              test_name="log1p_large"),
    op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_small_positive(), []),
    op_record("logaddexp", 2, float_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("logaddexp2", 2, float_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("polyval", 2, number_dtypes, nonempty_nonscalar_array_shapes, jtu.rand_default(), []),
    op_record("ravel", 1, all_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("remainder", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), []),
    op_record("square", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
Ejemplo n.º 5
0
# Dtype groupings built from float_dtypes, int_dtypes and complex_dtypes
# (defined earlier in the file, outside this view; assumed to be sequences so
# that ``+`` concatenates them in the order written).
# default_dtypes: floats followed by ints -- complex dtypes are excluded.
default_dtypes = float_dtypes + int_dtypes
# numeric_dtypes: floats, then complex, then ints.
numeric_dtypes = float_dtypes + complex_dtypes + int_dtypes

# Immutable description of a single op test case. ``diff_modes`` holds the
# object passed to ``op_record`` unchanged (in this file, ``["rev"]`` or ``[]``).
OpRecord = collections.namedtuple(
    "OpRecord",
    "name nargs dtypes rng diff_modes test_name",
)


def op_record(name, nargs, dtypes, rng, diff_modes, test_name=None):
    """Create an ``OpRecord`` for one op.

    ``diff_modes`` is stored as-is (the same object is kept). A falsy
    ``test_name`` (``None`` or ``""``) falls back to the op ``name``.
    """
    resolved_name = test_name if test_name else name
    return OpRecord(
        name=name,
        nargs=nargs,
        dtypes=dtypes,
        rng=rng,
        diff_modes=diff_modes,
        test_name=resolved_name,
    )


# Special-function test cases, one OpRecord per function. Positional arguments
# to op_record are (name, nargs, dtypes, rng, diff_modes).
# NOTE(review): diff_modes presumably lists the autodiff modes the test class
# checks (["rev"] = reverse mode, [] = none) -- confirm in LaxBackedScipyTests.
JAX_SPECIAL_FUNCTION_RECORDS = [
    op_record("gammaln", 1, float_dtypes, jtu.rand_positive(), ["rev"]),
    # NOTE(review): digamma is the only entry with no diff modes; the reason
    # is not recorded here.
    op_record("digamma", 1, float_dtypes, jtu.rand_positive(), []),
    op_record("erf", 1, float_dtypes, jtu.rand_small_positive(), ["rev"]),
    op_record("erfc", 1, float_dtypes, jtu.rand_small_positive(), ["rev"]),
    op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive(), ["rev"]),
    op_record("logit", 1, float_dtypes, jtu.rand_small_positive(), ["rev"]),
    op_record("expit", 1, float_dtypes, jtu.rand_small_positive(), ["rev"]),
]

# Short local alias for itertools.combinations_with_replacement.
from itertools import combinations_with_replacement as CombosWithReplacement


class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        """Return a zero-argument callable that samples one argument per input.

        Args:
          rng: callable invoked as ``rng(shape, dtype)`` to draw one argument.
          shapes: shapes, paired positionally with ``dtypes``.
          dtypes: dtypes, paired positionally with ``shapes``.

        Returns:
          A thunk that, on every call, builds a new list holding
          ``rng(shape, dtype)`` for each pair. Pairing follows ``zip``, so
          extra entries in the longer sequence are ignored.
        """
        def make_args():
            pairs = zip(shapes, dtypes)
            return [rng(shp, dt) for shp, dt in pairs]

        return make_args