def test_unary_cwise_real_ops_1(self):
  """Runs the first group of real-valued unary cwise ops through the checker.

  All ops below are passed, as a single list, to
  `self._test_unary_cwise_ops` with `False` as its second positional argument.
  """
  # For real x, 1 + x**2 >= 1, which keeps the input inside acosh's domain.
  shifted_acosh = [lambda x: math_ops.acosh(1 + math_ops.square(x))]
  remaining_ops = [
      math_ops.abs,
      math_ops.acos, math_ops.asin, math_ops.asinh,
      math_ops.atan, math_ops.atanh,
      math_ops.bessel_i0e, math_ops.bessel_i1e,
      math_ops.cos, math_ops.cosh,
      math_ops.digamma,
      math_ops.erf, math_ops.erfc, math_ops.erfinv,
      math_ops.exp, math_ops.expm1,
      math_ops.inv,
      math_ops.is_finite, math_ops.is_inf,
      math_ops.lgamma,
      math_ops.log, math_ops.log1p,
      math_ops.ndtri,
  ]
  self._test_unary_cwise_ops(shifted_acosh + remaining_ops, False)
def test_unary_cwise_real_ops_1(self):
  """Runs the first group of real-valued unary cwise ops through the checker.

  All ops below are passed, as a single list, to
  `self._test_unary_cwise_ops` with `False` as its second positional argument.
  """
  # For real x, 1 + x**2 >= 1, which keeps the input inside acosh's domain.
  shifted_acosh = [lambda x: math_ops.acosh(1 + math_ops.square(x))]
  remaining_ops = [
      math_ops.abs,
      math_ops.acos, math_ops.asin, math_ops.asinh,
      math_ops.atan, math_ops.atanh,
      math_ops.bessel_i0e, math_ops.bessel_i1e,
      math_ops.cos, math_ops.cosh,
      math_ops.digamma,
      math_ops.erf, math_ops.erfc,
      math_ops.exp, math_ops.expm1,
      math_ops.inv,
      math_ops.is_finite, math_ops.is_inf,
      math_ops.lgamma,
      math_ops.log, math_ops.log1p,
  ]
  self._test_unary_cwise_ops(shifted_acosh + remaining_ops, False)
def test_unary_cwise_real_ops_1(self):
  """Runs the first group of real-valued unary cwise ops through the checker.

  Skipped on ROCm builds. Otherwise the `math_ops` ops followed by the two
  `special_math_ops` Bessel ops are passed, as one list, to
  `self._test_unary_cwise_ops` with `False` as its second positional argument.
  """
  if test.is_built_with_rocm():
    # TODO(rocm):
    # This fails on ROCm...see JIRA ticket 236756
    self.skipTest("Fails on ROCM")

  math_module_ops = [
      # For real x, 1 + x**2 >= 1, which keeps the input inside acosh's domain.
      lambda x: math_ops.acosh(1 + math_ops.square(x)),
      math_ops.abs,
      math_ops.acos, math_ops.asin, math_ops.asinh,
      math_ops.atan, math_ops.atanh,
      math_ops.cos, math_ops.cosh,
      math_ops.digamma,
      math_ops.erf, math_ops.erfc, math_ops.erfinv,
      math_ops.exp, math_ops.expm1,
      math_ops.inv,
      math_ops.is_finite, math_ops.is_inf,
      math_ops.lgamma,
      math_ops.log, math_ops.log1p,
      math_ops.ndtri,
  ]
  special_module_ops = [
      special_math_ops.bessel_i0e,
      special_math_ops.bessel_i1e,
  ]
  all_ops = math_module_ops + special_module_ops
  self._test_unary_cwise_ops(all_ops, False)
def _generate_unary_cwise_math_cases():
  """Builds the parameter list for the unary cwise math op tests.

  Returns:
    A list of `(name, op, dataset_factory)` tuples. `name` is a string label,
    `op` is the unary callable under test, and `dataset_factory` is a
    zero-argument callable that builds a `Dataset` with
    `Dataset.from_tensor_slices` over input appropriate for that group of ops
    (ints for bitwise, bools for logical, complex for complex, floats for real).
  """
  # TODO(rachelim): Consolidate tests with pfor when APIs are somewhat shared.
  bitwise_cases = [("Invert", bitwise_ops.invert)]
  logical_cases = [("LogicalNot", math_ops.logical_not)]
  complex_cases = [
      ("Angle", math_ops.angle),
      ("ComplexAbs", math_ops.abs),
      ("Conj", math_ops.conj),
      ("Imag", math_ops.imag),
      ("Real", math_ops.real),
  ]
  real_cases = [
      ("Abs", math_ops.abs),
      ("Acos", math_ops.acos),
      # 1 + x**2 >= 1 keeps real inputs inside acosh's domain.
      ("Acosh", lambda x: math_ops.acosh(1 + math_ops.square(x))),
      ("Asin", math_ops.asin),
      ("Asinh", math_ops.asinh),
      ("Atan", math_ops.atan),
      ("Atanh", math_ops.atanh),
      ("BesselI0e", math_ops.bessel_i0e),
      ("BesselI1e", math_ops.bessel_i1e),
      ("Ceil", math_ops.ceil),
      ("Cos", math_ops.cos),
      ("Cosh", math_ops.cosh),
      ("Digamma", math_ops.digamma),
      ("Elu", nn.elu),
      ("Erf", math_ops.erf),
      ("Erfc", math_ops.erfc),
      ("Exp", math_ops.exp),
      ("Expm1", math_ops.expm1),
      ("Floor", math_ops.floor),
      ("Inv", math_ops.inv),
      ("IsFinite", math_ops.is_finite),
      ("IsInf", math_ops.is_inf),
      ("Lgamma", math_ops.lgamma),
      ("Log", math_ops.log),
      ("Log1p", math_ops.log1p),
      ("Neg", math_ops.negative),
      ("Reciprocal", math_ops.reciprocal),
      ("Relu", nn.relu),
      ("Relu6", nn.relu6),
      ("Rint", math_ops.rint),
      ("Round", math_ops.round),
      ("Rsqrt", math_ops.rsqrt),
      ("Selu", nn.selu),
      ("Sigmoid", math_ops.sigmoid),
      ("Sign", math_ops.sign),
      ("Sin", math_ops.sin),
      ("Sinh", math_ops.sinh),
      ("Softplus", nn.softplus),
      ("Softsign", nn.softsign),
      ("Sqrt", math_ops.sqrt),
      ("Square", math_ops.square),
      ("Tan", math_ops.tan),
      ("Tanh", math_ops.tanh),
  ]

  # Inputs are generated once here and captured by the factories below.
  random_input = np.random.rand(3, 5)
  complex_component = np.random.rand(3, 5)
  random_int = np.random.randint(0, 10, (7, 3, 5))

  def bitwise_dataset_factory():
    return dataset_ops.Dataset.from_tensor_slices(random_int)

  def logical_dataset_factory():
    # np.random.rand samples from [0, 1), so a threshold of 0 would give
    # (almost surely) an all-True input. 0.5 yields a mix of True and False.
    return dataset_ops.Dataset.from_tensor_slices(random_input > 0.5)

  def random_dataset_factory():
    return dataset_ops.Dataset.from_tensor_slices(random_input)

  def complex_dataset_factory():
    return dataset_ops.Dataset.from_tensor_slices(
        math_ops.complex(random_input, complex_component))

  case_factory_pairs = [
      (bitwise_cases, bitwise_dataset_factory),
      (logical_cases, logical_dataset_factory),
      (complex_cases, complex_dataset_factory),
      (real_cases, random_dataset_factory),
  ]
  return [(name, op, factory)
          for cases, factory in case_factory_pairs
          for name, op in cases]
def safe_acosh(x):
  """Returns `acosh(1 + x**2)`.

  For real-valued `x`, the argument `1 + x**2` is never below 1, so it stays
  inside the domain of acosh.
  """
  shifted = 1 + math_ops.square(x)
  return math_ops.acosh(shifted)
def test_unary_cwise_ops(self):
  """Runs each unary cwise op through `self._test_loop_fn` over 3 iterations.

  Complex ops get a complex `[3, 5]` input built from two random real parts.
  Real ops get a random float `[3, 5]` input. Each `loop_fn` iteration
  returns the op applied to the full input and to row `i`. When the per-row
  output is float32, it also returns the gradient of a sum-of-squares loss
  with respect to that row, if the gradient is not None.
  """
  complex_ops = [
      math_ops.angle,
      math_ops.imag,
      math_ops.complex_abs,
      math_ops.real,
      math_ops.conj,
  ]
  real_ops = [
      # 1 + x**2 >= 1 keeps real inputs inside acosh's domain.
      lambda x: math_ops.acosh(1 + math_ops.square(x)),
      math_ops.abs,
      math_ops.acos,
      math_ops.asin,
      math_ops.asinh,
      math_ops.atan,
      math_ops.atanh,
      math_ops.bessel_i0e,
      math_ops.bessel_i1e,
      math_ops.cos,
      math_ops.cosh,
      math_ops.digamma,
      math_ops.erf,
      math_ops.erfc,
      math_ops.exp,
      math_ops.expm1,
      math_ops.inv,
      math_ops.is_finite,
      math_ops.is_inf,
      math_ops.lgamma,
      math_ops.log,
      math_ops.log1p,
      math_ops.neg,
      math_ops.negative,
      math_ops.reciprocal,
      math_ops.rint,
      math_ops.round,
      math_ops.rsqrt,
      math_ops.sigmoid,
      math_ops.sign,
      math_ops.sin,
      math_ops.sinh,
      math_ops.sqrt,
      math_ops.square,
      math_ops.tan,
      math_ops.tanh,  # Previously listed twice; one entry is sufficient.
      nn.elu,
      nn.relu,
      nn.relu6,
      nn.selu,
      nn.softplus,
      nn.softsign,
  ]
  for op in complex_ops + real_ops:
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 5])
      g.watch(x)
      if op in complex_ops:
        # Build a complex input whose real and imaginary parts are both
        # watched by the tape.
        y = random_ops.random_uniform([3, 5])
        g.watch(y)
        x = math_ops.complex(x, y)

    # pylint: disable=cell-var-from-loop
    # loop_fn refills this list with the dtypes of its outputs. The same list
    # object is passed to _test_loop_fn below, so the contents are populated
    # when loop_fn runs.
    output_dtypes = []

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        y1 = op(x1)
        outputs = [op(x), y1]
        # Only float32 outputs get a scalar loss to differentiate.
        if y1.dtype == dtypes.float32:
          loss = math_ops.reduce_sum(y1 * y1)
        else:
          loss = None
      # The gradient is requested after leaving the tape's `with` block.
      if loss is not None:
        grad = g.gradient(loss, x1)
        if grad is not None:
          outputs.append(grad)
      # Replace contents in place so the alias held by _test_loop_fn sees it.
      output_dtypes[:] = [t.dtype for t in outputs]
      return outputs

    # pylint: enable=cell-var-from-loop

    self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=output_dtypes)