示例#1
0
文件: optix_test.py 项目: yxd886/jax
    def test_apply_every(self):
        """Check that ``apply_every(k)`` reproduces plain SGD every k-th step.

        Two optimizers are stepped in lockstep on ``self.per_step_updates``:

        * ``optix.sgd(LR, 0.0)``, which applies each update immediately;
        * ``apply_every(k) -> trace(decay=0) -> scale(-LR)``, which
          accumulates the incoming updates and only emits them on steps
          where ``i % k == k - 1``.

        On those steps the two parameter trees must agree; on every other
        step the ``apply_every`` chain must emit all-zero updates.
        """
        # The frequency of the application of sgd.
        k = 4

        # Reference: experimental/optix.py sgd (no momentum).
        optix_sgd_params = self.init_params
        sgd = optix.sgd(LR, 0.0)
        state_sgd = sgd.init(optix_sgd_params)

        # Candidate: experimental/optix.py sgd, applied only every k steps.
        optix_sgd_apply_every_params = self.init_params
        sgd_apply_every = optix.chain(optix.apply_every(k=k),
                                      optix.trace(decay=0, nesterov=False),
                                      optix.scale(-LR))
        state_sgd_apply_every = sgd_apply_every.init(
            optix_sgd_apply_every_params)
        for i in range(STEPS):
            # Apply a step of sgd.
            updates_sgd, state_sgd = sgd.update(self.per_step_updates,
                                                state_sgd)
            optix_sgd_params = optix.apply_updates(optix_sgd_params,
                                                   updates_sgd)

            # Apply a step of sgd_apply_every.
            updates_sgd_apply_every, state_sgd_apply_every = sgd_apply_every.update(
                self.per_step_updates, state_sgd_apply_every)
            optix_sgd_apply_every_params = optix.apply_updates(
                optix_sgd_apply_every_params, updates_sgd_apply_every)
            if i % k == k - 1:
                # Check equivalence. The relative tolerance must be tight:
                # rtol=100 would accept |x - y| up to 100 * |y|, which lets
                # almost any wrong accumulation through.
                for x, y in zip(tree_leaves(optix_sgd_apply_every_params),
                                tree_leaves(optix_sgd_params)):
                    np.testing.assert_allclose(x, y, atol=1e-6, rtol=1e-5)
            else:
                # Check the update is zero. Compare each leaf against zeros of
                # its own shape so every leaf is checked, independent of the
                # parameter tree's layout.
                for x in tree_leaves(updates_sgd_apply_every):
                    np.testing.assert_allclose(x, np.zeros_like(x),
                                               atol=1e-10, rtol=1e-5)
示例#2
0
def make_optimizer():
    """Build a momentum SGD optimizer driven by a learning-rate schedule.

    The momentum decay comes from ``FLAGS.optimizer_momentum``. Whether the
    momentum is Nesterov-style is set by ``FLAGS.optimizer_use_nesterov``.
    The step size is taken from ``lr_schedule``. The trailing ``scale(-1)``
    negates the updates so that applying them moves against the gradient.

    Returns:
        An ``optix.chain`` of momentum, scheduled scaling and negation.
    """
    momentum = optix.trace(
        decay=FLAGS.optimizer_momentum,
        nesterov=FLAGS.optimizer_use_nesterov,
    )
    scheduled = optix.scale_by_schedule(lr_schedule)
    negate = optix.scale(-1)
    return optix.chain(momentum, scheduled, negate)
示例#3
0
文件: fit.py 项目: gehring/neos
def get_solvers(model_constructor,
                pdf_transform=False,
                default_rtol=1e-10,
                default_atol=1e-10,
                default_max_iter=int(1e7),
                learning_rate=0.01):
    '''
    Wraps a series of functions that perform maximum likelihood fitting in the
    `two_phase_solve` method found in the `fax` python module. This allows for
    the calculation of gradients of the best-fit parameters with respect to upstream
    parameters that control the underlying model, i.e. the event yields (which are
    then parameterized by weights or similar).

    Args:
            model_constructor: Function that takes in the parameters of the observable,
            and returns a model object (and background-only parameters).
            pdf_transform: If True, fit parameters are mapped through
            `to_bounded_vec(pars, bounds)` before evaluating the logpdf, and the
            constrained mu is mapped with `to_inf(mu, bounds[0])`.
            default_rtol: Relative tolerance for `twophase.default_convergence_test`.
            default_atol: Absolute tolerance for `twophase.default_convergence_test`.
            default_max_iter: Maximum iterations for `twophase.default_solver`.
            learning_rate: Step size of each gradient-descent update
            (update = -learning_rate * gradient).
    Returns:
            g_fitter, c_fitter: Callable functions that perform global and constrained fits
            respectively. Each takes `(init, hyper_pars)`, where `hyper_pars` is
            `(constrained_mu, nn_pars)`. Differentiable :)
    '''

    # One plain gradient-descent step: updates = -learning_rate * grad.
    # Previously this was hard-coded to -1e-2 and `learning_rate` was ignored.
    # The default of 0.01 keeps the old behaviour.
    gradient_descent = optix.scale(-learning_rate)

    def make_model(hyper_pars):
        """Build (mu, global NLL, constrained NLL, bounds) for `hyper_pars`."""
        constrained_mu, nn_pars = hyper_pars[0], hyper_pars[1]
        m, bonlypars = model_constructor(nn_pars)

        bounds = m.config.suggested_bounds()
        constrained_mu = to_inf(constrained_mu,
                                bounds[0]) if pdf_transform else constrained_mu

        # Fits are performed against the background-only expected data
        # (including auxiliary data).
        exp_bonly_data = m.expected_data(bonlypars, include_auxdata=True)

        def expected_logpdf(
                pars):  # maps pars to bounded space if pdf_transform = True

            return (m.logpdf(to_bounded_vec(pars, bounds), exp_bonly_data)
                    if pdf_transform else m.logpdf(pars, exp_bonly_data))

        def global_fit_objective(pars):  # NLL
            return -expected_logpdf(pars)[0]

        def constrained_fit_objective(nuis_par):  # NLL
            # mu is held fixed as the first parameter; only nuis_par varies.
            pars = jax.numpy.concatenate(
                [jax.numpy.asarray([constrained_mu]), nuis_par])
            return -expected_logpdf(pars)[0]

        return constrained_mu, global_fit_objective, constrained_fit_objective, bounds

    def global_bestfit_minimized(hyper_param):
        """Return one gradient-descent step on the global NLL."""
        _, nll, _, _ = make_model(hyper_param)

        def bestfit_via_grad_descent(param):  # gradient descent
            g = jax.grad(nll)(param)
            updates, _ = gradient_descent.update(g,
                                                 gradient_descent.init(param))
            return optix.apply_updates(param, updates)

        return bestfit_via_grad_descent

    def constrained_bestfit_minimized(hyper_param):
        """Return one gradient-descent step on the constrained NLL."""
        mu, _, cnll, _ = make_model(hyper_param)

        def bestfit_via_grad_descent(param):  # gradient descent
            # param[0] is discarded and re-pinned to `mu`; only the remaining
            # entries are updated. (Named `nuisance` rather than `np` to avoid
            # shadowing the usual numpy alias.)
            nuisance = param[1:]
            g = jax.grad(cnll)(nuisance)
            updates, _ = gradient_descent.update(
                g, gradient_descent.init(nuisance))
            nuisance = optix.apply_updates(nuisance, updates)
            return jax.numpy.concatenate([jax.numpy.asarray([mu]), nuisance])

        return bestfit_via_grad_descent

    convergence_test = twophase.default_convergence_test(
        rtol=default_rtol,
        atol=default_atol,
    )
    global_solver = twophase.default_solver(
        convergence_test=convergence_test,
        max_iter=default_max_iter,
    )
    # Both fits share the same solver configuration.
    constrained_solver = global_solver

    def g_fitter(init, hyper_pars):
        return twophase.two_phase_solve(
            global_bestfit_minimized,
            init,
            hyper_pars,
            solvers=(global_solver, ),
        )

    def c_fitter(init, hyper_pars):
        return twophase.two_phase_solve(
            constrained_bestfit_minimized,
            init,
            hyper_pars,
            solvers=(constrained_solver, ),
        )

    return g_fitter, c_fitter
示例#4
0
def make_optimizer(lr_schedule, momentum_decay):
  """Build classical (non-Nesterov) momentum SGD with a scheduled step size.

  Args:
    lr_schedule: Schedule passed to ``optix.scale_by_schedule``.
    momentum_decay: Decay factor for the ``optix.trace`` momentum buffer.

  Returns:
    An ``optix.chain`` of momentum, scheduled scaling, and a final sign
    flip (``scale(-1)``) so the updates descend the gradient.
  """
  transforms = (
      optix.trace(decay=momentum_decay, nesterov=False),
      optix.scale_by_schedule(lr_schedule),
      optix.scale(-1),
  )
  return optix.chain(*transforms)