def _compute_new_c(self, state0: NestedMap, i_i: JTensor, i_g: JTensor,
                   f_g: JTensor) -> JTensor:
  """Returns the next cell state with the input gate tied to the forget gate.

  The update is `f * state0.c + tanh(i_i) * (1 - f)`, where
  `f = sigmoid(f_g)`. No separate input gate is used, so `i_g` must be None.

  Args:
    state0: Previous state; only its `c` field is read here.
    i_i: Pre-activation cell candidate (passed through `tanh`).
    i_g: Unused input-gate pre-activation; asserted to be None.
    f_g: Pre-activation forget gate (passed through `sigmoid`).

  Returns:
    The new cell state.
  """
  asserts.none(i_g)
  # Evaluate the sigmoid once and reuse it for both the kept and written terms.
  keep = jax.nn.sigmoid(f_g)
  candidate = jnp.tanh(i_i)
  retained = keep * state0.c
  # Same as candidate * (1 - keep), written as in the original expression.
  written = candidate - candidate * keep
  return retained + written
def test_none_raises(self, value):
  """Checks that `asserts.none` raises `ValueError` for a non-None `value`.

  Covers three message forms: the default message, one built from a
  caller-supplied `value_str`, and a fully custom `msg`. `assertRaisesRegex`
  treats its pattern as a regular expression. Each expected message is
  therefore passed through `re.escape`, so that the trailing `.` and any
  metacharacters in `str(value)` (e.g. the brackets of a list) are matched
  literally.

  Args:
    value: A non-None value. It is presumably supplied by a parameterization
      decorator outside this view — confirm.
  """
  import re  # pylint: disable=g-import-not-at-top

  def _exact(message):
    # Escape the literal text, then anchor at the end as the original did.
    return re.escape(message) + '$'

  with self.assertRaisesRegex(
      ValueError, _exact(f'`value={value}` must be `None`.')):
    asserts.none(value)
  with self.assertRaisesRegex(
      ValueError, _exact(f'`custom_value={value}` must be `None`.')):
    asserts.none(value, value_str=f'custom_value={value}')
  custom_error_msg = 'This is a custom error message.'
  with self.assertRaisesRegex(ValueError, _exact(custom_error_msg)):
    asserts.none(value, msg=custom_error_msg)
def test_none(self):
  """Checks that `asserts.none` accepts None without raising."""
  asserts.none(None)