Example #1
0
    def _compute_new_c(self, state0: NestedMap, i_i: JTensor, i_g: JTensor,
                       f_g: JTensor) -> JTensor:
        """Blends `state0.c` with `tanh(i_i)` using `sigmoid(f_g)` as weight.

        The returned value is
        `sigmoid(f_g) * state0.c + (tanh(i_i) - tanh(i_i) * sigmoid(f_g))`,
        which is algebraically `f * c + (1 - f) * tanh(i_i)` with
        `f = sigmoid(f_g)` and `c = state0.c`.

        NOTE(review): this looks like a coupled input/forget gate update of
        an LSTM cell state -- confirm against the enclosing class.

        Args:
          state0: Previous state; only its `c` field is read here.
          i_i: Input that is passed through `tanh`.
          i_g: Must be `None`; this is checked with `asserts.none`.
          f_g: Input that is passed through `sigmoid`.

        Returns:
          The sum of the retained `state0.c` term and the newly added
          `tanh(i_i)` term.
        """
        asserts.none(i_g)

        # The same sigmoid weight is shared by both terms below.
        keep = jax.nn.sigmoid(f_g)
        candidate = jnp.tanh(i_i)

        retained = keep * state0.c
        added = candidate - candidate * keep
        return retained + added
Example #2
0
 def test_none_raises(self, value):
     """Checks the `ValueError` messages raised by `asserts.none(value)`.

     Three call forms are exercised: the default message, a message built
     from a custom `value_str`, and a fully custom `msg`.

     Args:
       value: The value handed to `asserts.none`; every call here is
         expected to raise for it. Presumably supplied by a test
         parameterization that is not visible in this block.
     """
     # Local import keeps this method self-contained.
     import re

     def anchored(text):
         # `assertRaisesRegex` interprets its second argument as a regular
         # expression. Escape the expected text so that the trailing `.` and
         # any metacharacters in `str(value)` are matched literally, then
         # anchor the pattern at the end of the message.
         return re.escape(text) + '$'

     with self.assertRaisesRegex(
             ValueError, anchored(f'`value={value}` must be `None`.')):
         asserts.none(value)

     with self.assertRaisesRegex(
             ValueError, anchored(f'`custom_value={value}` must be `None`.')):
         asserts.none(value, value_str=f'custom_value={value}')

     custom_error_msg = 'This is a custom error message.'
     with self.assertRaisesRegex(ValueError, anchored(custom_error_msg)):
         asserts.none(value, msg=custom_error_msg)
Example #3
0
 def test_none(self):
     """Checks that `asserts.none` accepts `None` without raising."""
     # The call itself is the assertion: any exception fails the test.
     asserts.none(None)