コード例 #1
0
ファイル: base.py プロジェクト: lindermanlab/jxf
        def nonconjugate_m_step(expectations, nonconjugate_params,
                                conjugate_params):
            """Fit the non-conjugate parameters with a gradient-based M step.

            Minimizes the weighted-average negative expected log probability,
            accumulated over the enclosing scope's ``dataset`` and ``weights``,
            using Nesterov's accelerated gradient inside ``lax.while_loop``.

            Args:
                expectations: Iterable zipped with ``dataset`` and ``weights``;
                    each item is passed as ``expectations=`` to
                    ``cls.expected_log_prob``.
                nonconjugate_params: Starting values; converted with
                    ``cls.nonconj_params_to_unconstrained`` to initialize the
                    optimizer.
                conjugate_params: Passed unchanged to
                    ``cls.expected_log_prob``.

            Returns:
                A pair ``(params, lp)``: the optimized parameters mapped back
                through ``cls.nonconj_params_from_unconstrained`` and the
                negated objective value last recorded in the loop state.
            """

            def objective(unconstrained):
                # Weighted mean of the negative expected log probability.
                constrained = cls.nonconj_params_from_unconstrained(
                    unconstrained, **kwargs)
                weighted_lp = 0
                total_weight = 0
                for expects, data_dict, these_weights in zip(
                        expectations, dataset, weights):
                    per_datum_lp = cls.expected_log_prob(
                        constrained,
                        conjugate_params,
                        expectations=expects,
                        **data_dict,
                        **kwargs)
                    weighted_lp += np.sum(these_weights * per_datum_lp)
                    total_weight += np.sum(these_weights)
                return -weighted_lp / total_weight

            # Nesterov's accelerated gradient supplies the update rule.
            opt_init, opt_update, get_params = nesterov(
                nesterov_step_size, nesterov_mass)
            objective_and_grad = value_and_grad(objective)

            def check_convergence(state):
                # Loop condition: continue while the objective still moved by
                # more than the threshold AND the iteration cap is not reached.
                # The two flags are combined by multiplication.
                itr, _, (prev_val, curr_val) = state
                still_changing = abs(curr_val - prev_val) > nesterov_threshold
                within_budget = itr < nesterov_max_iters
                return still_changing * within_budget

            def step(state):
                # One optimizer update; the value pair slides forward so it
                # always holds the two most recent objective evaluations.
                itr, opt_state, (_, prev_val) = state
                curr_val, grads = objective_and_grad(get_params(opt_state))
                return (itr + 1,
                        opt_update(itr, grads, opt_state),
                        (prev_val, curr_val))

            # Build the initial loop state; the "previous" value starts at
            # infinity so the first convergence check cannot stop the loop
            # on the value difference alone.
            start = cls.nonconj_params_to_unconstrained(
                nonconjugate_params, **kwargs)
            start_opt_state = opt_init(start)
            start_vals = (np.inf, objective(start))
            init_state = (0, start_opt_state, start_vals)

            # Run the optimizer and unpack what it left behind.
            itr_count, final_opt_state, (_, last_val) = lax.while_loop(
                check_convergence, step, init_state)
            params = get_params(final_opt_state)
            lp = -1 * last_val
            if verbosity >= Verbosity.LOUD:
                print("Nesterov converged in ", itr_count, "iterations")

            return cls.nonconj_params_from_unconstrained(params, **kwargs), lp
コード例 #2
0
ファイル: qnnops.py プロジェクト: phyjoon/circuit_comparison
def get_optimizer(name, optim_args, scheduler):
    """Build an optimizer triple from ``optimizers`` by name.

    Args:
        name: Optimizer name, case-insensitive. Supported values are
            ``'adam'`` and ``'nesterov'``.
        optim_args: Extra keyword arguments for the optimizer constructor.
            Either a mapping, a ``'key:value,key:value'`` string whose values
            are parsed as floats (whitespace around keys and values and empty
            entries are ignored), or a falsy value for no extra arguments.
            The caller's object is never modified.
        scheduler: Passed as the first positional argument of the optimizer
            constructor (its step size or step-size schedule).

    Returns:
        The tuple ``(init_fun, update_fun, get_params)``.

    Raises:
        ValueError: If ``name`` is not supported, or a string entry is not
            of the form ``key:float``.
    """
    name = name.lower()
    if isinstance(optim_args, str):
        parsed = {}
        for entry in optim_args.split(','):
            if not entry.strip():
                # Tolerate empty input and stray/trailing commas.
                continue
            key, sep, value = entry.partition(':')
            if not sep:
                raise ValueError(
                    f'Malformed optimizer argument {entry!r}; '
                    f'expected "key:value".')
            parsed[key.strip()] = float(value)
        optim_args = parsed
    else:
        # Work on a copy so the default inserted below never leaks into the
        # caller's dictionary.
        optim_args = dict(optim_args) if optim_args else {}

    if name == 'adam':
        init_fun, update_fun, get_params = optimizers.adam(
            scheduler, **optim_args)
    elif name == 'nesterov':
        # nesterov has no default for mass, so supply one when absent.
        optim_args.setdefault('mass', 0.1)
        init_fun, update_fun, get_params = optimizers.nesterov(
            scheduler, **optim_args)
    else:
        raise ValueError(f'An optimizer {name} is not supported. ')
    print(f'Loaded an optimization {name} - {optim_args}')
    return init_fun, update_fun, get_params
コード例 #3
0
    def get_optimizer(self, optim=None, stage='learn', step_size=None):
        """Return the ``(opt_init, opt_update, get_params)`` triple for a code.

        Args:
            optim: Optimizer code: 1 momentum, 2 rmsprop, 3 adagrad,
                4 Nesterov, 5 SGD; any other value selects adam. When
                ``None``, ``self.optim_learn`` is used if ``stage`` equals
                ``'learn'`` and ``self.optim_proj`` otherwise.
            stage: Only consulted when ``optim`` is ``None``.
            step_size: Step size handed to the optimizer constructor;
                ``None`` falls back to ``self.step_size``.

        Returns:
            A plain tuple ``(opt_init, opt_update, get_params)``.
        """
        if optim is None:
            optim = self.optim_learn if stage == 'learn' else self.optim_proj
        if step_size is None:
            step_size = self.step_size

        # (code, verbose message, lazy constructor). The constructors are
        # wrapped in lambdas so only the selected one is ever invoked.
        candidates = (
            (1, "With momentum optimizer",
             lambda: momentum(step_size=step_size, mass=0.95)),
            (2, "With rmsprop optimizer",
             lambda: rmsprop(step_size, gamma=0.9, eps=1e-8)),
            (3, "With adagrad optimizer",
             lambda: adagrad(step_size, momentum=0.9)),
            (4, "With Nesterov optimizer",
             lambda: nesterov(step_size, 0.9)),
            (5, "With SGD optimizer",
             lambda: sgd(step_size)),
        )

        # adam is the fallback for every unrecognized code.
        message = "With adam optimizer"
        build = lambda: adam(step_size)
        for code, label, factory in candidates:
            if optim == code:
                message, build = label, factory
                break

        if self.verb > 2:
            print(message)
        opt_init, opt_update, get_params = build()

        return opt_init, opt_update, get_params
コード例 #4
0
ファイル: vae-stax.py プロジェクト: adambozson/vae-jax
        Relu,
        Dense(512),
        Relu,
        FanOut(2),
        stax.parallel(Dense(latent_dim), Dense(latent_dim)),
    )

    # Decoder network: Dense(512) -> Relu -> Dense(512) -> Relu, then a final
    # Dense layer with data.num_pixels outputs.
    decoder_init, decode = stax.serial(
        Dense(512), Relu, Dense(512), Relu, Dense(data.num_pixels)
    )

    # Training hyperparameters.
    step_size = 1e-3
    num_epochs = 100
    batch_size = 128

    # Nesterov-momentum optimizer, unpacked into its (init, update,
    # parameter-getter) functions.
    opt_init, opt_update, get_params = optimizers.nesterov(step_size, mass=0.9)

    # Initialisation
    # Split the seed key three ways: one key each for the encoder and decoder
    # initializers, with `key` rebound to the remaining one.
    key = random.PRNGKey(0)
    enc_init_key, dec_init_key, key = random.split(key, 3)
    # The second argument of each init call is presumably the input shape
    # (batch, features) -- confirm against the stax API. The first returned
    # value is discarded.
    _, enc_init_params = encoder_init(enc_init_key, (batch_size, data.num_pixels))
    _, dec_init_params = decoder_init(dec_init_key, (batch_size, latent_dim))
    # Encoder and decoder parameters are optimized jointly as one pair.
    init_params = (enc_init_params, dec_init_params)

    opt_state = opt_init(init_params)

    @jit
    def update(i, key, opt_state, images):
        """Apply one optimizer step and return the new optimizer state.

        The loss is the negated first output of ``elbo(key, params, images)``
        divided by ``len(images)``; its gradient with respect to the current
        parameters is fed to ``opt_update`` together with step index ``i``.
        """

        def scaled_negative_elbo(params):
            return -elbo(key, params, images)[0] / len(images)

        current_params = get_params(opt_state)
        gradients = grad(scaled_negative_elbo)(current_params)
        return opt_update(i, gradients, opt_state)
コード例 #5
0
 def __init__(self, learning_rate, mass=0.9):
     """Set up a Nesterov optimizer wrapper.

     Args:
         learning_rate: Forwarded to the parent initializer. The step size
             is afterwards read from ``self.lr`` (presumably set by the
             parent -- confirm).
         mass: Stored as ``self.mass`` and passed to ``nesterov`` as
             ``mass``.
     """
     super().__init__(learning_rate)
     self.mass = mass
     # Unpack first, then expose the three optimizer functions as attributes.
     init_fn, update_fn, params_fn = nesterov(step_size=self.lr,
                                              mass=self.mass)
     self.opt_init = init_fn
     self.opt_update = update_fn
     self.get_params = params_fn