def test_unknown_shape(self):
    """Check that a concrete function traced with a fully unknown shape
    agrees with calling ``lisht`` directly on float32 inputs of rank 1-4.
    """
    unknown_spec = tf.TensorSpec(shape=None, dtype=tf.float32)
    concrete_fn = lisht.get_concrete_function(unknown_spec)
    # Shapes (1,), (1, 2), (1, 2, 3), (1, 2, 3, 4).
    for rank in range(1, 5):
        input_shape = tuple(range(1, rank + 1))
        ones = tf.ones(shape=input_shape, dtype=tf.float32)
        self.assertAllClose(concrete_fn(ones), lisht(ones))
def verify_funcs_are_equivalent(self, dtype):
    """Compare ``lisht`` with the reference ``_lisht_py`` on random input.

    Both the forward outputs and the gradients with respect to the input
    are compared using a dtype-aware tolerance.

    Args:
        dtype: Dtype passed to ``ndarray.astype`` for the random 4x4 input
            drawn uniformly from [-10, 10).
    """
    x_np = np.random.uniform(-10, 10, size=(4, 4)).astype(dtype)
    x = tf.convert_to_tensor(x_np)
    # Record only the forward passes. The tape is persistent so that both
    # outputs can be differentiated from the same recording.
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(x)  # x is a plain tensor, not a Variable, so watch it.
        y_native = lisht(x)
        y_py = _lisht_py(x)
    self.assertAllCloseAccordingToType(y_native, y_py)
    # Take gradients outside the tape context so the gradient computation
    # itself is not recorded (TF warns this is inefficient on persistent
    # tapes).
    grad_native = tape.gradient(y_native, x)
    grad_py = tape.gradient(y_py, x)
    # A persistent tape keeps its resources alive until it is released.
    del tape
    self.assertAllCloseAccordingToType(grad_native, grad_py)
def test_lisht(self, dtype):
    """Check ``lisht`` against precomputed values for a few scalar inputs.

    The expected values equal x * tanh(x) for x in {-2, -1, 0, 1, 2}.
    """
    inputs = [-2.0, -1.0, 0.0, 1.0, 2.0]
    expected = [1.9280552, 0.7615942, 0.0, 0.7615942, 1.9280552]
    x = tf.constant(inputs, dtype=dtype)
    expected_result = tf.constant(expected, dtype=dtype)
    self.assertAllCloseAccordingToType(lisht(x), expected_result)
def test_lisht(dtype):
    """Check ``lisht`` against precomputed values for a few scalar inputs.

    The expected values equal x * tanh(x) for x in {-2, -1, 0, 1, 2};
    the comparison tolerance is chosen by ``test_utils`` from the dtype.
    """
    inputs = [-2.0, -1.0, 0.0, 1.0, 2.0]
    expected = [1.9280552, 0.7615942, 0.0, 0.7615942, 1.9280552]
    x = tf.constant(inputs, dtype=dtype)
    expected_result = tf.constant(expected, dtype=dtype)
    test_utils.assert_allclose_according_to_type(lisht(x), expected_result)