def test_apply_if_finite(self, opt_builder):
    """Checks that `apply_if_finite` rejects NaN updates up to its limit.

    The optimiser returned by `opt_builder()` is wrapped with
    `wrappers.apply_if_finite(..., 2)`. The test then alternates finite and
    NaN gradients and checks the parameter value together with the
    `last_finite`, `notfinite_count` and `total_notfinite` state fields.

    Args:
      opt_builder: zero-argument callable returning the optimiser to wrap.
    """
    one = jnp.ones([])
    nan = jnp.array(jnp.nan)

    def fn(x):
        return x * hk.get_parameter('p', [], init=hk.initializers.Constant(0.))

    fn = hk.without_apply_rng(hk.transform(fn))
    params = fn.init(jax.random.PRNGKey(1905), one)
    opt = wrappers.apply_if_finite(opt_builder(), 2)
    state = opt.init(params)
    grads_fn = jax.grad(self.variant(fn.apply))

    def run_step(x, current_params, current_state):
        # One optimisation step: gradient at input `x`, then wrapped update.
        grads = grads_fn(current_params, x)
        updates, next_state = opt.update(grads, current_state, current_params)
        return update.apply_updates(current_params, updates), next_state

    def param_value(current_params):
        # `jax.tree_util.tree_leaves` is used instead of the top-level
        # `jax.tree_flatten` alias, which is deprecated in JAX.
        return float(jax.tree_util.tree_leaves(current_params)[0])

    # Do one successful param update
    params, state = run_step(one, params, state)
    # We know exactly what should be the value of params since we are
    # effectively using sgd in all cases.
    self.assertEqual(-1., param_value(params))
    self.assertTrue(bool(state.last_finite))

    # Check 2 rejected param updates
    for step in range(2):
        params, state = run_step(nan, params, state)
        self.assertEqual(-1., param_value(params))
        self.assertFalse(bool(state.last_finite))
        self.assertEqual(step + 1, int(state.notfinite_count))

    # Next successful param update
    params, state = run_step(one, params, state)
    self.assertEqual(-2., param_value(params))
    self.assertTrue(bool(state.last_finite))

    # Again 2 rejected param updates
    for step in range(2):
        params, state = run_step(nan, params, state)
        self.assertEqual(-2., param_value(params))
        self.assertFalse(bool(state.last_finite))
        self.assertEqual(step + 1, int(state.notfinite_count))

    # Next param update with NaN is accepted since we reached maximum
    params, state = run_step(nan, params, state)
    self.assertTrue(bool(jnp.isnan(jax.tree_util.tree_leaves(params)[0])))
    # 2 + 2 rejected steps plus the final accepted NaN step.
    self.assertEqual(5, int(state.total_notfinite))
def test_apply_if_finite_pmap(self):
    """Checks `apply_if_finite` when the whole update step runs under pmap.

    Unlike in `test_apply_if_finite`:
      * pmap is applied to the gradient computation and the optimisation;
      * the NaNs are caused inside the function and do not come from the
        inputs.
    """
    half = jnp.ones([1]) / 2.
    two = jnp.ones([1]) * 2.  # Causes a NaN in arctanh

    def fn(x):
        return jnp.arctanh(x) * hk.get_parameter(
            'p', [], init=hk.initializers.Constant(0.))

    fn = hk.without_apply_rng(hk.transform(fn))

    opt = wrappers.apply_if_finite(alias.sgd(1.), 2)

    def fn_update(params, opt_state, x):
        grads = jax.grad(fn.apply)(params, x)
        grads = jax.lax.psum(grads, axis_name='i')
        updates, new_opt_state = opt.update(grads, opt_state, params)
        new_params = update.apply_updates(params, updates)
        return new_params, new_opt_state

    fn_update = jax.pmap(fn_update, axis_name='i')

    params = fn.init(jax.random.PRNGKey(1905), half)
    opt_state = opt.init(params)
    # Add a leading axis of size 1 for pmap. `jax.tree_util.tree_map` is used
    # instead of the top-level `jax.tree_map` alias, which is deprecated in JAX.
    params = jax.tree_util.tree_map(lambda x: x[None], params)
    opt_state = jax.tree_util.tree_map(lambda x: x[None], opt_state)

    # The state fields below carry the leading size-1 axis added above, so the
    # single entry is indexed explicitly before converting to a Python scalar
    # instead of relying on implicit size-1 array conversion.

    # Do one successful param update
    params, opt_state = fn_update(params, opt_state, half)
    self.assertTrue(bool(opt_state.last_finite[0]))

    # Check 2 rejected param updates
    for step in range(2):
        params, opt_state = fn_update(params, opt_state, two)
        self.assertFalse(bool(opt_state.last_finite[0]))
        self.assertEqual(step + 1, int(opt_state.notfinite_count[0]))

    # Next successful param update
    params, opt_state = fn_update(params, opt_state, half)
    self.assertTrue(bool(opt_state.last_finite[0]))

    # Again 2 rejected param updates
    for step in range(2):
        params, opt_state = fn_update(params, opt_state, two)
        self.assertFalse(bool(opt_state.last_finite[0]))
        self.assertEqual(step + 1, int(opt_state.notfinite_count[0]))

    # Next param update with NaN is accepted since we reached maximum
    params, opt_state = fn_update(params, opt_state, two)
    # 2 + 2 rejected steps plus the final accepted NaN step.
    self.assertEqual(5, int(opt_state.total_notfinite[0]))