def nonconjugate_m_step(expectations, nonconjugate_params, conjugate_params):
    # M step: optimize the non-conjugate parameters via gradient methods.
    def objective(params):
        nonconjugate_params = cls.nonconj_params_from_unconstrained(params, **kwargs)
        lp = 0
        num_datapoints = 0
        for expects, data_dict, these_weights in zip(expectations, dataset, weights):
            _lp = cls.expected_log_prob(nonconjugate_params,
                                        conjugate_params,
                                        expectations=expects,
                                        **data_dict,
                                        **kwargs)
            lp += np.sum(these_weights * _lp)
            num_datapoints += np.sum(these_weights)
        return -lp / num_datapoints

    # Optimize with Nesterov's accelerated gradient
    opt_init, opt_update, get_params = nesterov(nesterov_step_size, nesterov_mass)

    def check_convergence(state):
        itr, _, (prev_val, curr_val) = state
        return (abs(curr_val - prev_val) > nesterov_threshold) * (itr < nesterov_max_iters)

    def step(state):
        itr, opt_state, (_, prev_val) = state
        curr_val, grads = value_and_grad(objective)(get_params(opt_state))
        opt_state = opt_update(itr, grads, opt_state)
        return (itr + 1, opt_state, (prev_val, curr_val))

    # Initialize and run the optimizer
    init_params = cls.nonconj_params_to_unconstrained(nonconjugate_params, **kwargs)
    init_state = (0, opt_init(init_params), (np.inf, objective(init_params)))
    final_state = lax.while_loop(check_convergence, step, init_state)

    # Unpack the final state
    itr_count, params, lp = final_state[0], get_params(final_state[1]), -1 * final_state[2][1]
    if verbosity >= Verbosity.LOUD:
        print("Nesterov converged in ", itr_count, "iterations")

    return cls.nonconj_params_from_unconstrained(params, **kwargs), lp
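# The function above closes over names from an enclosing classmethod (cls, dataset,
# weights, kwargs, the nesterov_* hyperparameters, verbosity), so it is not runnable
# on its own. A minimal, self-contained sketch of the same lax.while_loop-plus-Nesterov
# pattern, with a toy quadratic objective and made-up hyperparameter values, might
# look like this:
import jax.numpy as jnp
from jax import lax, value_and_grad
from jax.example_libraries.optimizers import nesterov

def toy_objective(params):
    # Simple quadratic bowl standing in for the negative expected log prob.
    return jnp.sum((params - 3.0) ** 2)

step_size, mass = 1e-2, 0.9           # hypothetical settings
threshold, max_iters = 1e-6, 1000

opt_init, opt_update, get_params = nesterov(step_size, mass)

def check_convergence(state):
    itr, _, (prev_val, curr_val) = state
    return (jnp.abs(curr_val - prev_val) > threshold) & (itr < max_iters)

def step(state):
    itr, opt_state, (_, prev_val) = state
    curr_val, grads = value_and_grad(toy_objective)(get_params(opt_state))
    opt_state = opt_update(itr, grads, opt_state)
    return itr + 1, opt_state, (prev_val, curr_val)

init_params = jnp.zeros(2)
init_state = (0, opt_init(init_params), (jnp.array(jnp.inf), toy_objective(init_params)))
final_state = lax.while_loop(check_convergence, step, init_state)
print("converged params:", get_params(final_state[1]))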
from jax.example_libraries import optimizers  # assumed; older JAX exposes this as jax.experimental.optimizers


def get_optimizer(name, optim_args, scheduler):
    name = name.lower()
    # Parse "key:value,key:value" strings into a keyword-argument dict.
    if optim_args and isinstance(optim_args, str):
        optim_args = [kv.split(':') for kv in optim_args.split(',')]
        optim_args = {k: float(v) for k, v in optim_args}
    optim_args = optim_args or {}
    if name == 'adam':
        init_fun, update_fun, get_params = optimizers.adam(scheduler, **optim_args)
    elif name == 'nesterov':
        if 'mass' not in optim_args:
            optim_args['mass'] = 0.1
        init_fun, update_fun, get_params = optimizers.nesterov(scheduler, **optim_args)
    else:
        raise ValueError(f'Optimizer {name} is not supported.')
    print(f'Loaded optimizer {name} - {optim_args}')
    return init_fun, update_fun, get_params
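# A hypothetical call to the helper above. The scheduler argument can be a constant
# step size or one of the schedule factories in jax.example_libraries.optimizers,
# e.g. exponential_decay; the values below are illustrative only.
schedule = optimizers.exponential_decay(step_size=1e-2, decay_steps=1000, decay_rate=0.9)
init_fun, update_fun, get_params = get_optimizer('nesterov', 'mass:0.9', schedule)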
def get_optimizer(self, optim=None, stage='learn', step_size=None):
    if optim is None:
        if stage == 'learn':
            optim = self.optim_learn
        else:
            optim = self.optim_proj
    if step_size is None:
        step_size = self.step_size
    if optim == 1:
        if self.verb > 2:
            print("With momentum optimizer")
        opt_init, opt_update, get_params = momentum(step_size=step_size, mass=0.95)
    elif optim == 2:
        if self.verb > 2:
            print("With rmsprop optimizer")
        opt_init, opt_update, get_params = rmsprop(step_size, gamma=0.9, eps=1e-8)
    elif optim == 3:
        if self.verb > 2:
            print("With adagrad optimizer")
        opt_init, opt_update, get_params = adagrad(step_size, momentum=0.9)
    elif optim == 4:
        if self.verb > 2:
            print("With Nesterov optimizer")
        opt_init, opt_update, get_params = nesterov(step_size, 0.9)
    elif optim == 5:
        if self.verb > 2:
            print("With SGD optimizer")
        opt_init, opt_update, get_params = sgd(step_size)
    else:
        if self.verb > 2:
            print("With adam optimizer")
        opt_init, opt_update, get_params = adam(step_size)
    return opt_init, opt_update, get_params
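# A hypothetical use of the method above; `model`, `loss_fn`, and `init_params` are
# placeholders for the surrounding class and data, not names from the original code.
from jax import grad

opt_init, opt_update, get_params = model.get_optimizer(optim=4, step_size=1e-3)  # Nesterov
opt_state = opt_init(init_params)
for i in range(100):
    g = grad(loss_fn)(get_params(opt_state))
    opt_state = opt_update(i, g, opt_state)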
# Imports assumed for this snippet; the stax layers and example optimizers live in
# jax.example_libraries in current JAX releases. `data`, `latent_dim`, and `elbo`
# come from the surrounding script and are not shown here.
from jax import grad, jit, random
from jax.example_libraries import optimizers, stax
from jax.example_libraries.stax import Dense, FanOut, Relu

# Encoder: two hidden layers, then parallel heads for the latent mean and log-variance.
# (The opening of this stax.serial call was cut off in the original; it is
# reconstructed here to mirror the decoder.)
encoder_init, encode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    FanOut(2),
    stax.parallel(Dense(latent_dim), Dense(latent_dim)),
)
decoder_init, decode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    Dense(data.num_pixels),
)

step_size = 1e-3
num_epochs = 100
batch_size = 128
opt_init, opt_update, get_params = optimizers.nesterov(step_size, mass=0.9)

# Initialisation
key = random.PRNGKey(0)
enc_init_key, dec_init_key, key = random.split(key, 3)
_, enc_init_params = encoder_init(enc_init_key, (batch_size, data.num_pixels))
_, dec_init_params = decoder_init(dec_init_key, (batch_size, latent_dim))
init_params = (enc_init_params, dec_init_params)
opt_state = opt_init(init_params)

@jit
def update(i, key, opt_state, images):
    loss = lambda p: -elbo(key, p, images)[0] / len(images)
    g = grad(loss)(get_params(opt_state))
    return opt_update(i, g, opt_state)
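# A sketch of the training loop that would drive the jitted update step above. The
# batching helpers (num_batches, get_batch) are assumptions for illustration, not
# part of the original snippet.
for epoch in range(num_epochs):
    for batch_idx in range(num_batches):
        key, subkey = random.split(key)
        images = get_batch(batch_idx)           # hypothetical batching helper
        i = epoch * num_batches + batch_idx     # global step passed to the optimizer
        opt_state = update(i, subkey, opt_state, images)

params = get_params(opt_state)                  # final (encoder, decoder) parameters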
def __init__(self, learning_rate, mass=0.9):
    super().__init__(learning_rate)
    self.mass = mass
    self.opt_init, self.opt_update, self.get_params = nesterov(
        step_size=self.lr, mass=self.mass)
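# Hypothetical single update step using the wrapper above. The subclass name
# (NesterovOptimizer) and the base class that stores learning_rate as self.lr are
# assumptions inferred from the __init__ signature.
import jax.numpy as jnp
from jax import grad

opt = NesterovOptimizer(learning_rate=1e-2, mass=0.9)
state = opt.opt_init(jnp.ones(3))
g = grad(lambda p: jnp.sum(p ** 2))(opt.get_params(state))
state = opt.opt_update(0, g, state)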