Example #1
0
    def test_apply_if_finite(self, opt_builder):
        """`apply_if_finite` drops non-finite updates until its limit is hit.

        The wrapped optimiser is built with a limit of 2. Two consecutive NaN
        steps are rejected and leave the parameter unchanged. A finite step is
        applied and restarts the consecutive-failure count. After another two
        rejected NaN steps, the next NaN step is applied, so the parameter
        becomes NaN. `total_notfinite` counts every non-finite step seen (5).
        """
        one = jnp.ones([])
        nan = jnp.array(jnp.nan)

        def fn(x):
            return x * hk.get_parameter('p', [],
                                        init=hk.initializers.Constant(0.))

        def first_leaf(tree):
            # The model declares a single parameter 'p', so the first leaf of
            # the params tree is that parameter.
            return jax.tree_util.tree_leaves(tree)[0]

        fn = hk.without_apply_rng(hk.transform(fn))
        params = fn.init(jax.random.PRNGKey(1905), one)
        opt = wrappers.apply_if_finite(opt_builder(), 2)
        state = opt.init(params)
        grads_fn = jax.grad(self.variant(fn.apply))
        # Do one successful param update
        grads = grads_fn(params, one)
        updates, state = opt.update(grads, state, params)
        params = update.apply_updates(params, updates)
        # We know exactly what should be the value of params since we are
        # effectively using sgd in all cases.
        self.assertEqual(-1., float(first_leaf(params)))
        self.assertTrue(bool(state.last_finite))
        # Check 2 rejected param updates
        for step in range(2):
            grads = grads_fn(params, nan)
            updates, state = opt.update(grads, state, params)
            params = update.apply_updates(params, updates)
            self.assertEqual(-1., float(first_leaf(params)))
            self.assertFalse(bool(state.last_finite))
            self.assertEqual(step + 1, int(state.notfinite_count))
        # Next successful param update
        grads = grads_fn(params, one)
        updates, state = opt.update(grads, state, params)
        params = update.apply_updates(params, updates)
        self.assertEqual(-2., float(first_leaf(params)))
        self.assertTrue(bool(state.last_finite))
        # Again 2 rejected param updates; the consecutive count restarts at 1
        # because the previous update was finite.
        for step in range(2):
            grads = grads_fn(params, nan)
            updates, state = opt.update(grads, state, params)
            params = update.apply_updates(params, updates)
            self.assertEqual(-2., float(first_leaf(params)))
            self.assertFalse(bool(state.last_finite))
            self.assertEqual(step + 1, int(state.notfinite_count))
        # Next param update with NaN is accepted since we reached maximum
        grads = grads_fn(params, nan)
        updates, state = opt.update(grads, state, params)
        params = update.apply_updates(params, updates)
        self.assertTrue(bool(jnp.isnan(first_leaf(params))))
        self.assertEqual(5, int(state.total_notfinite))
Example #2
0
    def test_apply_if_finite_pmap(self):
        """`apply_if_finite` behaves the same under `jax.pmap`.

        This mirrors the non-pmap test using the same schedule of finite and
        non-finite steps, with a limit of 2 consecutive non-finite updates.
        """
        # Unlike in `test_apply_if_finite`:
        # * pmap is applied to the gradient computation and the optimisation;
        # * the NaNs are caused inside the function and do not come from the inputs.
        half = jnp.ones([1]) / 2.
        two = jnp.ones([1]) * 2.  # Causes a NaN in arctanh

        def fn(x):
            return jnp.arctanh(x) * hk.get_parameter(
                'p', [], init=hk.initializers.Constant(0.))

        def first_leaf(tree):
            # The model declares a single parameter 'p', so the first leaf of
            # the params tree is that parameter (with a leading device axis).
            return jax.tree_util.tree_leaves(tree)[0]

        fn = hk.without_apply_rng(hk.transform(fn))

        opt = wrappers.apply_if_finite(alias.sgd(1.), 2)

        def fn_update(params, opt_state, x):
            grads = jax.grad(fn.apply)(params, x)
            grads = jax.lax.psum(grads, axis_name='i')
            updates, new_opt_state = opt.update(grads, opt_state, params)
            new_params = update.apply_updates(params, updates)
            return new_params, new_opt_state

        fn_update = jax.pmap(fn_update, axis_name='i')

        params = fn.init(jax.random.PRNGKey(1905), half)
        opt_state = opt.init(params)
        # Add a leading axis of size 1 so the trees can be mapped by pmap,
        # matching the leading axis of `half` / `two`.
        params = jax.tree_util.tree_map(lambda x: x[None], params)
        opt_state = jax.tree_util.tree_map(lambda x: x[None], opt_state)
        # Do one successful param update
        params, opt_state = fn_update(params, opt_state, half)
        self.assertTrue(bool(opt_state.last_finite))
        # Check 2 rejected param updates
        for step in range(2):
            params, opt_state = fn_update(params, opt_state, two)
            self.assertFalse(bool(opt_state.last_finite))
            self.assertEqual(step + 1, int(opt_state.notfinite_count))
        # Next successful param update
        params, opt_state = fn_update(params, opt_state, half)
        self.assertTrue(bool(opt_state.last_finite))
        # Again 2 rejected param updates
        for step in range(2):
            params, opt_state = fn_update(params, opt_state, two)
            self.assertFalse(bool(opt_state.last_finite))
            self.assertEqual(step + 1, int(opt_state.notfinite_count))
        # Next param update with NaN is accepted since we reached maximum
        params, opt_state = fn_update(params, opt_state, two)
        # The accepted NaN update must actually reach the parameters.
        self.assertTrue(bool(jnp.all(jnp.isnan(first_leaf(params)))))
        self.assertEqual(5, int(opt_state.total_notfinite))