def test_apply_every(self):
    """Compare plain SGD against SGD wrapped in `apply_every(k)`.

    Both optimizers start from `self.init_params` and receive
    `self.per_step_updates` on each of `STEPS` iterations. On every k-th
    iteration the two parameter trees are required to agree; on all other
    iterations the update emitted by the `apply_every` chain is required
    to be zero.
    """
    # The frequency of the application of sgd.
    k = 4
    # Expected update on the steps where nothing should be applied.
    # NOTE(review): presumably mirrors the structure of
    # `self.per_step_updates` (two leaves of two elements) -- confirm.
    zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

    def assert_leaves_close(actual_tree, desired_tree, atol, rtol):
        # Leaf-by-leaf comparison of two pytrees.
        leaf_pairs = zip(tree_leaves(actual_tree), tree_leaves(desired_tree))
        for actual, desired in leaf_pairs:
            np.testing.assert_allclose(actual, desired, atol=atol, rtol=rtol)

    # Reference optimizer: experimental/optix.py sgd.
    reference_params = self.init_params
    reference_opt = optix.sgd(LR, 0.0)
    reference_state = reference_opt.init(reference_params)

    # Optimizer under test: experimental/optix.py sgd behind apply_every.
    every_k_params = self.init_params
    every_k_opt = optix.chain(
        optix.apply_every(k=k),
        optix.trace(decay=0, nesterov=False),
        optix.scale(-LR),
    )
    every_k_state = every_k_opt.init(every_k_params)

    for step in range(STEPS):
        # One step of the reference sgd.
        reference_updates, reference_state = reference_opt.update(
            self.per_step_updates, reference_state)
        reference_params = optix.apply_updates(
            reference_params, reference_updates)

        # One step of the apply_every variant.
        every_k_updates, every_k_state = every_k_opt.update(
            self.per_step_updates, every_k_state)
        every_k_params = optix.apply_updates(every_k_params, every_k_updates)

        if step % k != k - 1:
            # Between applications: check the update is zero.
            assert_leaves_close(
                every_k_updates, zero_update, atol=1e-10, rtol=1e-5)
        else:
            # On every k-th step: check equivalence of the parameters.
            assert_leaves_close(
                every_k_params, reference_params, atol=1e-6, rtol=100)
def make_optimizer():
    """Build SGD with (optionally Nesterov) momentum and a custom lr schedule.

    The momentum decay and the Nesterov switch are read from
    `FLAGS.optimizer_momentum` and `FLAGS.optimizer_use_nesterov`; the step
    size comes from the `lr_schedule` defined outside this function.

    Returns:
      The `optix.chain` of the three transformations, applied in the order
      momentum trace -> schedule scaling -> multiplication by -1.
    """
    momentum = optix.trace(
        decay=FLAGS.optimizer_momentum,
        nesterov=FLAGS.optimizer_use_nesterov,
    )
    scheduled_step_size = optix.scale_by_schedule(lr_schedule)
    # Final factor of -1 flips the sign of the update (presumably so that
    # adding the update to the parameters descends the loss).
    negate = optix.scale(-1)
    return optix.chain(momentum, scheduled_step_size, negate)
def get_solvers(model_constructor,
                pdf_transform=False,
                default_rtol=1e-10,
                default_atol=1e-10,
                default_max_iter=int(1e7),
                learning_rate=0.01):
    """Build differentiable global and constrained maximum-likelihood fitters.

    Wraps a series of functions that perform maximum likelihood fitting in
    the `two_phase_solve` function found in the `fax` python module. This
    allows for the calculation of gradients of the best-fit parameters with
    respect to upstream parameters that control the underlying model, i.e.
    the event yields (which are then parameterized by weights or similar).

    Args:
        model_constructor: Function that takes in the parameters of the
            observable, and returns a model object (and background-only
            parameters).
        pdf_transform: If True, the constrained signal strength is passed
            through `to_inf` and the fit parameters through `to_bounded_vec`
            (using the model's suggested bounds) before the log-pdf is
            evaluated.
        default_rtol: Relative tolerance handed to
            `twophase.default_convergence_test`.
        default_atol: Absolute tolerance handed to
            `twophase.default_convergence_test`.
        default_max_iter: Iteration cap handed to `twophase.default_solver`.
        learning_rate: Step size of the plain gradient-descent update used
            as the fixed-point iteration of both fits.

    Returns:
        g_fitter, c_fitter: Callable functions that perform global and
        constrained fits respectively. Both take `(init, hyper_pars)`, where
        `hyper_pars[0]` is the signal strength to constrain to and
        `hyper_pars[1]` is forwarded to `model_constructor`.
        Differentiable :)
    """
    # Plain gradient descent: gradients are scaled by -learning_rate.
    # (Previously hard-coded to -1e-2, which silently ignored the
    # `learning_rate` argument; the default keeps the old step size.)
    gradient_descent = optix.scale(-learning_rate)

    def make_model(hyper_pars):
        # Build the model for these hyperparameters and return the (possibly
        # transformed) constrained mu, both NLL objectives and the bounds.
        constrained_mu, nn_pars = hyper_pars[0], hyper_pars[1]
        m, bonlypars = model_constructor(nn_pars)

        bounds = m.config.suggested_bounds()
        if pdf_transform:
            constrained_mu = to_inf(constrained_mu, bounds[0])

        exp_bonly_data = m.expected_data(bonlypars, include_auxdata=True)

        def expected_logpdf(pars):
            # Maps pars to bounded space if pdf_transform = True.
            if pdf_transform:
                return m.logpdf(to_bounded_vec(pars, bounds), exp_bonly_data)
            return m.logpdf(pars, exp_bonly_data)

        def global_fit_objective(pars):  # NLL
            return -expected_logpdf(pars)[0]

        def constrained_fit_objective(nuis_par):  # NLL
            # The signal strength is pinned; only nuisance parameters float.
            pars = jax.numpy.concatenate(
                [jax.numpy.asarray([constrained_mu]), nuis_par])
            return -expected_logpdf(pars)[0]

        return (constrained_mu, global_fit_objective,
                constrained_fit_objective, bounds)

    def global_bestfit_minimized(hyper_param):
        _, nll, _, _ = make_model(hyper_param)

        def bestfit_via_grad_descent(param):
            # One gradient-descent step on the full parameter vector.
            g = jax.grad(nll)(param)
            # The optimizer state is re-initialised on every step and the
            # returned state is discarded.
            updates, _ = gradient_descent.update(
                g, gradient_descent.init(param))
            return optix.apply_updates(param, updates)

        return bestfit_via_grad_descent

    def constrained_bestfit_minimized(hyper_param):
        mu, _, cnll, _ = make_model(hyper_param)

        def bestfit_via_grad_descent(param):
            # One gradient-descent step on the nuisance parameters only; the
            # first entry of `param` is replaced by the constrained mu.
            nuis_pars = param[1:]
            g = jax.grad(cnll)(nuis_pars)
            updates, _ = gradient_descent.update(
                g, gradient_descent.init(nuis_pars))
            nuis_pars = optix.apply_updates(nuis_pars, updates)
            return jax.numpy.concatenate(
                [jax.numpy.asarray([mu]), nuis_pars])

        return bestfit_via_grad_descent

    convergence_test = twophase.default_convergence_test(
        rtol=default_rtol,
        atol=default_atol,
    )
    global_solver = twophase.default_solver(
        convergence_test=convergence_test,
        max_iter=default_max_iter,
    )
    # Both fits share the same solver configuration.
    constrained_solver = global_solver

    def g_fitter(init, hyper_pars):
        return twophase.two_phase_solve(
            global_bestfit_minimized,
            init,
            hyper_pars,
            solvers=(global_solver, ),
        )

    def c_fitter(init, hyper_pars):
        return twophase.two_phase_solve(
            constrained_bestfit_minimized,
            init,
            hyper_pars,
            solvers=(constrained_solver, ),
        )

    return g_fitter, c_fitter
def make_optimizer(lr_schedule, momentum_decay):
    """Build a momentum optimizer whose step size follows `lr_schedule`.

    Args:
      lr_schedule: Schedule handed to `optix.scale_by_schedule`.
      momentum_decay: Decay handed to `optix.trace` (Nesterov is disabled).

    Returns:
      The `optix.chain` of the three transformations, applied in the order
      momentum trace -> schedule scaling -> multiplication by -1.
    """
    momentum = optix.trace(decay=momentum_decay, nesterov=False)
    scheduled_step_size = optix.scale_by_schedule(lr_schedule)
    # Final factor of -1 flips the sign of the update (presumably so that
    # adding the update to the parameters descends the loss).
    negate = optix.scale(-1)
    return optix.chain(momentum, scheduled_step_size, negate)