class ExperimentalOptimizersEquivalenceTest(chex.TestCase):

  def setUp(self):
    super().setUp()
    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

  @chex.all_variants()
  @parameterized.named_parameters(
      ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
      ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop', alias.rmsprop(LR, decay=.9, eps=0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('rmsprop_momentum',
       alias.rmsprop(LR, decay=.9, eps=0.1, momentum=0.9),
       optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
      ('adagrad', alias.adagrad(LR, 0., 0.),
       optimizers.adagrad(LR, 0.), 1e-5),
      # Scheduled learning-rate variants; names must be unique for
      # named_parameters (the originals repeated 'sgd', 'adam', etc.).
      ('sgd_sched', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
      ('adam_sched', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8),
       optimizers.adam(LR, 0.9, 0.999), 1e-4),
      ('rmsprop_sched', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
       optimizers.rmsprop(LR, .9, 0.1), 1e-5),
      ('rmsprop_momentum_sched',
       alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
       optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
      ('adagrad_sched', alias.adagrad(LR_SCHED, 0., 0.),
       optimizers.adagrad(LR, 0.), 1e-5),
  )
  def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                    rtol):
    # experimental/optimizers.py
    jax_params = self.init_params
    opt_init, opt_update, get_params = jax_optimizer
    state = opt_init(jax_params)
    for i in range(STEPS):
      state = opt_update(i, self.per_step_updates, state)
      jax_params = get_params(state)

    # optax
    optax_params = self.init_params
    state = optax_optimizer.init(optax_params)

    @self.variant
    def step(updates, state):
      return optax_optimizer.update(updates, state)

    for _ in range(STEPS):
      updates, state = step(self.per_step_updates, state)
      optax_params = update.apply_updates(optax_params, updates)

    # Check equivalence.
    chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
def get_model_and_optimizer(model_params, hyper_params, rng):
    """Builds the model and sets up the Adam optimizer.

    Args:
        model_params (list): Contains layer and activation weights for the
            forward function.
        hyper_params (dict): Hyperparameters of the model.
        rng (jax.random.PRNGKey): Random key for initializing parameter weights.

    Returns:
        psi (func): Forward function.
        opt_update (func): Function to update the optimizer state.
        opt_state: Initial optimizer state.
        get_params_from_opt (func): Extracts parameters from an optimizer state.
        init_params (list): Initial parameters.
        opt_init (func): Optimizer initialization function.
    """
    init_fun, psi = serial_InvNet(model_params, hyper_params)  # load model from serial
    _, init_params = init_fun(
        rng, (hyper_params['batch_size'], hyper_params['z_latent']))  # get initial params
    opt_init, opt_update, get_params_from_opt = adam(hyper_params['lr'])  # get optimizer
    opt_state = opt_init(init_params)
    return psi, opt_update, opt_state, get_params_from_opt, init_params, opt_init
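# Usage sketch for get_model_and_optimizer above. Only the six return values
# come from the function itself; `loss_fn` and `batches` are hypothetical
# stand-ins for a scalar loss and an iterable of latent samples.
from jax import value_and_grad

psi, opt_update, opt_state, get_params_from_opt, init_params, opt_init = \
    get_model_and_optimizer(model_params, hyper_params, rng)

for i, z in enumerate(batches):  # batches: assumed iterable of latent codes
    params = get_params_from_opt(opt_state)
    loss_val, grads = value_and_grad(loss_fn)(params, z)  # loss_fn is assumed
    opt_state = opt_update(i, grads, opt_state)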
def test_beta_bernoulli(auto_class, rtol):
    data = np.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = sample('beta', dist.Beta(1., 1.))
        sample('obs', dist.Bernoulli(f), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.08)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = auto_class(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    opt_state, constrain_fn = svi_init(rng_init, model_args=(data,),
                                       guide_args=(data,))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_,
                                            model_args=(data,),
                                            guide_args=(data,))
        return opt_state_, rng_

    opt_state, _ = lax.fori_loop(0, 300, body_fn, (opt_state, rng_train))
    median = guide.median(opt_state)
    assert_allclose(median['beta'], 0.8, rtol=rtol)
def train(rng: jnp.ndarray, bij_params: Sequence[jnp.ndarray],
          deq_params: Sequence[jnp.ndarray], num_steps: int, lr: float,
          num_samples: int) -> Tuple:
    """Train the ambient flow with the combined loss function.

    Note that `bij_fns` (the functions computing the shift and scale of the
    RealNVP affine transformation) and `deq_fn` (the function computing the
    mean and scale of the dequantization distribution) are captured from the
    enclosing scope rather than passed as arguments.

    Args:
        rng: Pseudo-random number generator seed.
        bij_params: List of arrays parameterizing the RealNVP bijectors.
        deq_params: Parameters of the mean and scale functions used in the
            log-normal dequantizer.
        num_steps: Number of gradient descent iterations.
        lr: Gradient descent learning rate.
        num_samples: Number of dequantization samples.

    Returns:
        out: A tuple containing the estimated parameters of the ambient flow
            density and the dequantization distribution. The other element is
            the trace of the loss function.
    """
    opt_init, opt_update, get_params = optimizers.adam(lr)

    def step(opt_state, it):
        step_rng = random.fold_in(rng, it)
        bij_params, deq_params = get_params(opt_state)
        loss_val, loss_grad = value_and_grad(loss, (1, 3))(
            step_rng, bij_params, bij_fns, deq_params, deq_fn, num_samples)
        loss_grad = tree_util.tree_map(
            partial(put.clip_and_zero_nans, clip_value=1.), loss_grad)
        opt_state = opt_update(it, loss_grad, opt_state)
        return opt_state, loss_val

    opt_state, trace = lax.scan(step, opt_init((bij_params, deq_params)),
                                jnp.arange(num_steps))
    bij_params, deq_params = get_params(opt_state)
    return (bij_params, deq_params), trace
def test_beta_bernoulli(auto_class):
    data = np.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = sample('beta', dist.Beta(np.ones(2), np.ones(2)))
        sample('obs', dist.Bernoulli(f), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = auto_class(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    opt_state, constrain_fn = svi_init(rng_init, model_args=(data,),
                                       guide_args=(data,))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_,
                                            model_args=(data,),
                                            guide_args=(data,))
        return opt_state_, rng_

    opt_state, _ = fori_loop(0, 1000, body_fn, (opt_state, rng_train))
    true_coefs = (np.sum(data, axis=0) + 1) / (data.shape[0] + 2)
    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1), opt_state,
                                               sample_shape=(1000,))
    assert_allclose(np.mean(posterior_samples['beta'], 0), true_coefs,
                    atol=0.04)
def main(args):
    # Generate some data.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    opt_init, opt_update, get_params = optimizers.adam(args.learning_rate)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    rng = PRNGKey(0)
    opt_state = svi_init(rng, model_args=(data,))

    # Training loop
    rng, = random.split(rng, 1)

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, opt_state_, rng_,
                                            model_args=(data,))
        return opt_state_, rng_

    opt_state, _ = lax.fori_loop(0, args.num_steps, body_fn, (opt_state, rng))

    # Report the final values of the variational parameters
    # in the guide after training.
    params = get_params(opt_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # For this simple (conjugate) model we know the exact posterior. In
    # particular we know that the variational distribution should be
    # centered near 3.0. So let's check this explicitly.
    assert np.abs(params["guide_loc"] - 3.0) < 0.1
def get_initial_state(system, rng, generate_x_obs_seq_init, dim_q, tol,
                      adam_step_size=2e-1, reg_coeff=5e-2, coarse_tol=1e-1,
                      max_iters=1000, max_num_tries=10):
    """Find an initial constraint-satisfying state.

    Uses a heuristic combination of gradient-based minimisation of the norm
    of a modified constraint function plus a subsequent projection step using
    a quasi-Newton method, to try to find an initial point `q` such that
    `max(abs(constr(q))) < tol`.
    """
    # Use optimizers to set optimizer initialization and update functions
    opt_init, opt_update, get_params = opt.adam(adam_step_size)

    # Define a compiled update step
    @api.jit
    def step(i, opt_state, x_obs_seq_init):
        q, = get_params(opt_state)
        (obj, constr), grad = system.value_and_grad_init_objective(
            q, x_obs_seq_init, reg_coeff)
        opt_state = opt_update(i, grad, opt_state)
        return opt_state, obj, constr

    for t in range(max_num_tries):
        logging.info(f'Starting try {t+1}')
        q_init = rng.standard_normal(dim_q)
        x_obs_seq_init = generate_x_obs_seq_init(rng)
        opt_state = opt_init((q_init,))
        for i in range(max_iters):
            opt_state_next, norm, constr = step(i, opt_state, x_obs_seq_init)
            if not np.isfinite(norm):
                logging.info('Adam iteration diverged')
                break
            max_abs_constr = maximum_norm(constr)
            if max_abs_constr < coarse_tol:
                logging.info('Within coarse_tol attempting projection.')
                q_init, = get_params(opt_state)
                state = ConditionedDiffusionHamiltonianState(
                    q_init, x_obs_seq=x_obs_seq_init)
                try:
                    state = jitted_solve_projection_onto_manifold_quasi_newton(
                        state, state, 1., system, tol)
                except ConvergenceError:
                    logging.info('Quasi-Newton iteration diverged.')
                if np.max(np.abs(system.constr(state))) < tol:
                    logging.info('Found constraint satisfying state.')
                    state.mom = system.sample_momentum(state, rng)
                    return state
            if i % 100 == 0:
                logging.info(f'Iteration {i: >6}: mean|constr|^2 = {norm:.3e} '
                             f'max|constr| = {max_abs_constr:.3e}')
            opt_state = opt_state_next
    raise RuntimeError(f'Did not find valid state in {max_num_tries} tries.')
def train(params, data, gradients, epochs=100, batch_size=10, lr=1e-3,
          shuffle=True, start=0, lr_decay=0, lr_decay_steps=1, loggers=[],
          log_every=10):
    if lr_decay:
        schedule = opt.inverse_time_decay(lr, lr_decay_steps, lr_decay,
                                          staircase=False)
    else:
        schedule = opt.constant(lr)

    opt_init, opt_update, get_params = opt.adam(schedule)
    update = get_opt_update(get_params, opt_update, gradients)
    opt_state = opt_init(params)

    for i in range(start, epochs):
        # TODO: shuffle in-place to reduce memory allocations (first, copy data)
        data, _ = shuffle_perm(data) if shuffle else (data, None)
        batches = batchify(data, batch_size)
        for data_batch in batches:
            opt_state = update(i, opt_state, data_batch)
        if i % log_every == 0:
            params = get_params(opt_state)
            logs = [log(params, i) for log in loggers]
            print(f"Epoch {i}", end=" ")
            for log in logs:
                print(f"[{log[0]} {float_log(log[1])}]", end=" ")
            print()
    return 0
def train_model(lr, iters, train_data, test_data, name='', plot_groups=None):
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_update = jit(opt_update)

    _, params = init_fn(rand_key, (-1, train_data[0].shape[-1]))
    opt_state = opt_init(params)

    train_psnrs = []
    test_psnrs = []
    xs = []
    if plot_groups is not None:
        plot_groups['Test PSNR'].append(f'{name}_test')
        plot_groups['Train PSNR'].append(f'{name}_train')

    for i in tqdm(range(iters), desc='train iter', leave=False):
        opt_state = opt_update(
            i, model_grad_loss(get_params(opt_state), *train_data), opt_state)
        if i % 25 == 0:
            train_psnr = model_psnr(get_params(opt_state), *train_data)
            test_psnr = model_psnr(get_params(opt_state), *test_data)
            train_psnrs.append(train_psnr)
            test_psnrs.append(test_psnr)
            xs.append(i)
            if plot_groups is not None:
                plotlosses_model.update(
                    {f'{name}_train': train_psnr, f'{name}_test': test_psnr},
                    current_step=i)
        if i % 100 == 0 and i != 0 and plot_groups is not None:
            plotlosses_model.send()

    if plot_groups is not None:
        plotlosses_model.send()
    results = {
        'state': get_params(opt_state),
        'train_psnrs': train_psnrs,
        'test_psnrs': test_psnrs,
        'xs': xs,
    }
    return results
def test_svi():
    def model(key):
        n1 = func.sample('n1', dist.normal(jnp.array(10.), jnp.array(10.)),
                         key)
        return n1

    def q(params, key):
        n1 = func.sample('n1', dist.normal(params['n1_mean'],
                                           params['n1_std']), key)
        return {'n1': n1}

    optimizer = jax_optim.adam(0.05)
    svi = func.svi(model, q, func.elbo, optimizer,
                   initial_params={'n1_mean': jnp.array(0.),
                                   'n1_std': jnp.array(1.)})

    keys = jax.random.split(jax.random.PRNGKey(123), 500)
    for i in range(500):
        loss = svi(keys[i])
        if i % 100 == 0:
            print(f"Step {i}: {loss}")

    inferred_n1_mean = svi.get_param('n1_mean')
    inferred_n1_std = svi.get_param('n1_std')
    tu.check_close(inferred_n1_mean, 10.611818)
    tu.check_close(inferred_n1_std, 9.024648)
def load_model(url):
    train_state = deepx.optimise.TrainState.restore(url)
    hparams = train_state.hparams
    model = deepx.resnet.ResNet(hparams.hidden_channels, 1, hparams.depth)
    # optimizers.adam returns an Optimizer namedtuple (init_fn, update_fn,
    # params_fn); params_fn extracts the weights from the restored state.
    opt = optimizers.adam(0.001)
    params = opt.params_fn(train_state.opt_state)
    return model, hparams, params
def main():
    net_init, net_apply = stax.serial(
        stax.Dense(128), stax.Softplus,
        stax.Dense(128), stax.Softplus,
        stax.Dense(2),
    )
    opt_init, opt_update, get_params = optimizers.adam(1e-3)

    out_shape, net_params = net_init(jax.random.PRNGKey(seed=42),
                                     input_shape=(-1, 2))
    opt_state = opt_init(net_params)

    loss_history = []
    print("Training...")
    train_step = get_train_step(opt_update, get_params, net_apply)
    for i in range(2000):
        x = sample_batch(size=128)
        loss, opt_state = train_step(i, opt_state, x)
        loss_history.append(loss.item())
    print("Training Finished...")
    plot_gradients(loss_history, opt_state, get_params, net_params, net_apply)
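# main() above calls get_train_step without defining it. A minimal sketch of
# what it plausibly looks like, given the call site: the real objective is not
# shown in the original, so loss_fn below is a placeholder assumption.
from jax import grad, jit
import jax.numpy as jnp


def get_train_step(opt_update, get_params, net_apply):
    def loss_fn(params, x):
        # Placeholder objective for illustration only.
        return jnp.mean(net_apply(params, x) ** 2)

    @jit
    def train_step(i, opt_state, x):
        params = get_params(opt_state)
        loss = loss_fn(params, x)
        opt_state = opt_update(i, grad(loss_fn)(params, x), opt_state)
        return loss, opt_state

    return train_step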
def fit(vae, rng_key, data, data_vari, step_size=1e-3, max_iter=1000):
    '''Fit the VAE, keeping the parameters with the lowest observed loss.

    Args:
        data: array-like of shape (obs, features).
    '''
    start_params = vae.params
    opt_init, update_params, get_params = adam(step_size)
    opt_state = opt_init(start_params)
    history = []
    min_loss_params = (1e10, None)
    for i in trange(max_iter, smoothing=0):
        params = get_params(opt_state)
        rng_key, subkey = random.split(rng_key)
        loss, grads = value_and_grad(objective)(params, vae, subkey, data,
                                                data_vari)
        opt_state = update_params(i, grads, opt_state)
        if loss < min_loss_params[0]:
            min_loss_params = (loss, params)
        history.append(float(loss))
    vae = VAE(partial(vae.raw_encode, min_loss_params[1][0]), vae.raw_encode,
              partial(vae.raw_decode, min_loss_params[1][1]), vae.raw_decode,
              None, None, min_loss_params[1], vae.n_latent_dims, data_vari)
    return vae._replace(generate=partial(generate_samples, vae),
                        fit=partial(fit, vae)), history
def fit(params: Dict, sequences: List[str], n: int,
        step_size: float = 0.001) -> Dict:
    """
    Return weights fitted to predict the next letter in each sequence.

    The training loop is as follows. Per step in the training loop, we loop
    over each "length batch" of sequences and tune weights in order of the
    length of each sequence. For example, if we have sequences of length 302,
    305, and 309, over K training epochs, we will perform 3xK updates, one
    step of updates for each length.

    To get batching of sequences by length done, we call on
    ``batch_sequences`` from our ``utils.py`` module, which returns a list of
    sub-lists, in which each sub-list contains the indices in the original
    list of sequences that are of a particular length.

    :param params: mLSTM1900 parameters.
    :param sequences: List of sequences to evotune on.
    :param n: The number of iterations to evotune on.
    :param step_size: Adam learning rate.
    """
    xs, ys = length_batch_input_outputs(sequences)

    init, update, get_params = adam(step_size=step_size)

    @jit
    def step(i, state, x, y):
        """
        Perform one step of evolutionary updating.

        This function is closed inside `fit` because we need access to the
        variables in its scope, particularly the update and get_params
        functions. By structuring the function this way, we can JIT-compile
        it, and thus gain a massive speed-up!

        Note that `x` and `y` are passed in as arguments rather than closed
        over: a jitted function bakes in closed-over arrays as constants at
        trace time, so closing over the loop variables would silently reuse
        the first length batch on every call.

        :param i: The current iteration of the training loop.
        :param state: Current state of parameters from jax.
        """
        params = get_params(state)
        g = grad(evotune_loss)(params, x, y)
        state = update(i, g, state)
        return state

    state = init(params)

    for i in range(n):
        for x, y in zip(xs, ys):
            state = step(i, state, x, y)
    return get_params(state)
def fit(self, X):
    from jax.tree_util import tree_leaves  # for the L1 penalty below

    opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
    opt_state = opt_init((self.encoder_params, self.decoder_params))

    def loss(params, inputs):
        encoder_params, decoder_params = params
        # Encode the batch and reconstruct it from the latent code (the
        # original applied both networks to the full dataset X instead of
        # the batch, and never fed the encoding to the decoder).
        enc = self.encoder_apply(encoder_params, inputs)
        dec = self.decoder_apply(decoder_params, enc)
        # Reconstruction error plus an L1 penalty over all parameter leaves.
        l1 = sum(np.abs(p).sum() for p in tree_leaves(params))
        return np.square(inputs - dec).sum() + 1e-3 * l1

    @jit
    def step(i, opt_state, inputs):
        params = get_params(opt_state)
        gradient = grad(loss)(params, inputs)
        return opt_update(i, gradient, opt_state)

    print('Training autoencoder...')
    batch_size, itercount = 32, itertools.count()
    key = random.PRNGKey(0)
    for epoch in range(5):
        temp_key, key = random.split(key)
        X = random.permutation(temp_key, X)
        for batch_index in range(0, X.shape[0], batch_size):
            opt_state = step(next(itercount), opt_state,
                             X[batch_index:batch_index + batch_size])

    self.encoder_params, self.decoder_params = get_params(opt_state)
def test_uniform_normal():
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        alpha = sample('alpha', dist.Uniform(0, 1))
        loc = sample('loc', dist.Uniform(0, alpha))
        sample('obs', dist.Normal(loc, 0.1), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = AutoDiagonalNormal(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    opt_state, constrain_fn = svi_init(rng_init, model_args=(data,),
                                       guide_args=(data,))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_,
                                            model_args=(data,),
                                            guide_args=(data,))
        return opt_state_, rng_

    opt_state, _ = fori_loop(0, 1000, body_fn, (opt_state, rng_train))
    median = guide.median(opt_state)
    assert_allclose(median['loc'], true_coef, rtol=0.05)
    # test .quantiles method
    quantiles = guide.quantiles(opt_state, [0.2, 0.5])
    assert_allclose(quantiles['loc'][1], true_coef, rtol=0.1)
def test_beta_bernoulli():
    data = np.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = sample("beta", dist.Beta(1., 1.))
        sample("obs", dist.Bernoulli(f), obs=data)

    def guide():
        alpha_q = param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = param("beta_q", 1.0, constraint=constraints.positive)
        sample("beta", dist.Beta(alpha_q, beta_q))

    opt_init, opt_update, get_params = optimizers.adam(0.05)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    rng_init, rng_train = random.split(random.PRNGKey(1))
    opt_state, constrain_fn = svi_init(rng_init, model_args=(data,))

    def body_fn(i, val):
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_,
                                            model_args=(data,))
        return opt_state_, rng_

    opt_state, _ = fori_loop(0, 300, body_fn, (opt_state, rng_train))
    params = constrain_fn(get_params(opt_state))
    assert_allclose(params['alpha_q'] / (params['alpha_q'] + params['beta_q']),
                    0.8, atol=0.05, rtol=0.05)
def train(rng, params, predict, X, y):
    """Generic train function called for each slice.

    Responsible for, given an rng key, a set of parameters to be trained,
    some inputs X and some outputs y, finetuning the params on X and y
    according to some internally defined training configuration.
    """
    iterations = 65
    batch_size = 32
    step_size = 0.01

    opt_init, opt_update, get_params = optimizers.adam(step_size)
    opt_state = opt_init(params)

    @jit
    def update(_, i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, predict, batch), opt_state)

    temp, rng = random.split(rng)
    batches = data_stream(rng, batch_size, X, y)
    for i in range(iterations):
        temp, rng = random.split(rng)
        opt_state = update(temp, i, opt_state, next(batches))

    return get_params(opt_state)
def args_to_op(optimizer_string, lr, mom=0.9, var=0.999, eps=1e-7):
    return {
        "gd": lambda lr, *unused: op.sgd(lr),
        "sgd": lambda lr, *unused: op.sgd(lr),
        "momentum": lambda lr, mom, *unused: op.momentum(lr, mom),
        "adam": lambda lr, mom, var, eps: op.adam(lr, mom, var, eps),
    }[optimizer_string.lower()](lr, mom, var, eps)
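# Example dispatch through args_to_op: "adam" consumes all four
# hyperparameters, while "sgd" ignores everything past the learning rate.
# Each branch returns the usual (init, update, get_params) triple.
opt_init, opt_update, get_params = args_to_op("adam", 1e-3)
sgd_init, sgd_update, sgd_get_params = args_to_op("sgd", 1e-2)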
def fit(self, X, y, batch_size=5, n_iter=10000, lr=0.001, lr_type='constant'):
    X = np.array(X).astype(np.float32)
    y = np.array(y).reshape(-1, 1).astype(np.float32)
    m, n = X.shape

    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(self.params)

    epochs = ceil(n_iter / floor(m / batch_size))
    step = 0
    for i in range(epochs):
        for j in range(0, m, batch_size):
            X_batch = X[j:j + batch_size]
            y_batch = y[j:j + batch_size]
            cur_loss, grads = value_and_grad(self.mse_loss)(self.params,
                                                            X_batch, y_batch)
            # Pass a running step count so Adam's bias correction advances
            # (the original always passed 0 here).
            opt_state = opt_update(step, grads, opt_state)
            self.params = get_params(opt_state)
            step += 1
        print("Epoch: ", i, end=" ")
        print("cost: ", self.mse_loss(self.params, X, y))
def main():
    dataset = tfds.load('mnist', shuffle_files=True)
    tr = iter(dataset['train'].batch(128).prefetch(1).repeat())
    ts = iter(dataset['test'].batch(1).repeat())

    ## build the container for keeping the weights
    weights_container = []

    ## building the CNN
    # When building the network, avoid `if` statements inside the subroutine:
    # they can break up the computation graph. For that reason, weight
    # initialization and the forward pass are kept as separate functions,
    # which also lets us use jit for acceleration.
    feature_map_nos = [128, 256, 512, 10]
    input_shape = [1, 28, 28, 1]
    jax_fcn_init(weights_container, input_shape, feature_map_nos, 3)
    print('weights initialized ...')

    ##### creating the optimizer #####
    # 1) jax.experimental.optimizers must be imported into the main namespace,
    #    hence the `from ... import ...` form.
    # 2) Calling an optimizer returns three different functions:
    #    2.1) opt_init: initializes the optimizer. It is told which variables
    #         to optimize, as in opt_init(params), and builds the matching
    #         optimizer state for them (e.g. momentum or the current beta
    #         terms), which must be kept around for the next update.
    #    2.2) opt_update: takes the gradients and the optimizer state and
    #         updates the weights. The state holds quantities remembered from
    #         previous steps, such as momentum.
    #    2.3) opt_get_params: returns the weights being optimized.
    learning_rate = 1e-4
    opt_init, opt_update, opt_get_params = optimizers.adam(learning_rate)
    ## initialize the optimizer and get the state object
    opt_status = opt_init(weights_container)

    ##### training loop #####
    for training_step in range(5000):
        tr_c = next(tr)
        image, label = tr_c['image'].numpy(), tr_c['label'].numpy()
        image, label = image.astype('float32'), label.astype('float32')
        ## computing the loss and gradients, then applying the update
        current_loss, gradients = jax.value_and_grad(loss)(
            weights_container, image, one_hot(label))
        opt_status = opt_update(training_step, gradients, opt_status)
        weights_container = opt_get_params(opt_status)
        print('step {} loss:{} accuracy:{}'.format(
            training_step, current_loss,
            accuracy(image, weights_container, label)))
def initialize_optimizers(policy_net_params, value_net_params):
    """Initialize optimizers for the policy and value params."""
    # ppo_opt_init, ppo_opt_update = optimizers.sgd(step_size=1e-3)
    # val_opt_init, val_opt_update = optimizers.sgd(step_size=1e-3)
    ppo_opt_init, ppo_opt_update = optimizers.adam(
        step_size=1e-3, b1=0.9, b2=0.999, eps=1e-08)
    value_opt_init, value_opt_update = optimizers.adam(
        step_size=1e-3, b1=0.9, b2=0.999, eps=1e-08)
    ppo_opt_state = ppo_opt_init(policy_net_params)
    value_opt_state = value_opt_init(value_net_params)
    return (ppo_opt_state, ppo_opt_update), (value_opt_state, value_opt_update)
def test_optim_multi_params():
    params = {'x': np.array([1., 1., 1.]), 'y': np.array([-1., -1., -1.])}
    opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)
    opt_state = opt_init(params)
    for i in range(1000):
        opt_state = step(i, opt_state, opt_update, get_params)
    for _, param in get_params(opt_state).items():
        assert np.allclose(param, np.zeros(3))
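# The test above relies on a `step` helper that is not defined in the snippet.
# A plausible sketch, assuming the objective is the squared norm of all
# parameter leaves (consistent with the assertion that the params converge
# to zero); the loss itself is an assumption.
from functools import partial
from jax import grad, jit
import jax.numpy as jnp


def quadratic_loss(params):
    return sum(jnp.sum(p ** 2) for p in params.values())


@partial(jit, static_argnums=(2, 3))  # the function arguments are static
def step(i, opt_state, opt_update, get_params):
    params = get_params(opt_state)
    return opt_update(i, grad(quadratic_loss)(params), opt_state)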
def objective_funcs(func, x0):
    init, update, get_params = optimizers.adam(1e-2)

    def myupdate(i, state):
        params = get_params(state)
        return update(i, jax.grad(func)(params), state)

    return init(x0), myupdate, get_params
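# Typical driver loop for objective_funcs; the quadratic objective and the
# iteration count here are purely illustrative.
import jax.numpy as jnp

state, myupdate, get_params = objective_funcs(lambda x: jnp.sum(x ** 2),
                                              jnp.ones(3))
for i in range(500):
    state = myupdate(i, state)
x_opt = get_params(state)  # approaches the minimizer at zero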
def __init__(self, activation=Relu):
    self.activation = activation
    self.net_init, self.net_apply = network(activation)
    self.optimizer = optimizers.adam(step_size=1e-3)
    rng = random.PRNGKey(0)
    in_shape = (-1, 1)
    self.out_shape, self.net_params = self.net_init(rng, in_shape)
def minimize_ADAM(fun, x0, data, step_size=0.01, f_tol=1e-5, n_iter_max=1000):
    """Minimize function with ADAM.

    Args:
        fun: function to minimize, which takes (x, data) as input
        x0: ndarray: initial solution
        data: ndarray: full data to subsample

        # Optimizer hyperparams
        step_size: positive scalar for step-size to pass into minimize_adam
        f_tol: positive scalar for value of norm of f change before termination
        n_iter_max: maximum number of sgd steps before termination

    Returns:
        x_opt: Optimized x
        loss: Minimized loss value
        n_iter: Number of iterations at termination
        delta_f: Value of change of f at termination
    """
    # function wrappers
    value_and_grad_fun = jit(value_and_grad(fun))

    # initialize optimizer and static quantities; n is closed over rather
    # than carried, since jnp.arange needs a concrete size
    opt_init, opt_update, get_params = adam(step_size=step_size)
    opt_state = opt_init(x0)
    key = random.PRNGKey(0)
    data = jnp.asarray(data)
    n = jnp.shape(data)[0]

    @jit  # one optimization step; the carry tracks termination info
    def step(carry):
        i, loss, delta_f, opt_state, key = carry
        key, subkey = random.split(key)
        ind = random.permutation(subkey, n)  # reshuffle the data each step
        params = get_params(opt_state)
        new_loss, g = value_and_grad_fun(params, data[ind])
        delta_f = jnp.abs(new_loss - loss)
        return i + 1, new_loss, delta_f, opt_update(i, g, opt_state), key

    @jit  # continue while the change in f exceeds f_tol and we are under the
    # iteration budget (the original referenced undefined norm_g / g_tol here)
    def not_converged(carry):
        i, loss, delta_f, opt_state, key = carry
        return jnp.logical_and(delta_f > f_tol, i < n_iter_max)

    # run optimizer until termination
    carry = 0, jnp.inf, jnp.inf, opt_state, key
    carry = while_loop(not_converged, step, carry)

    # extract values from carry
    n_iter, loss, delta_f, opt_state, _ = carry
    x_opt = get_params(opt_state)
    return x_opt, loss, n_iter, delta_f
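# Usage sketch for minimize_ADAM; the least-squares objective and the
# synthetic data below are illustrative only.
import jax.numpy as jnp


def mse(x, batch):
    # batch is a shuffled view of the full data passed to minimize_ADAM
    return jnp.mean((batch @ x - 1.0) ** 2)


data = jnp.ones((100, 3))
x_opt, loss, n_iter, delta_f = minimize_ADAM(mse, jnp.zeros(3), data)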
def optimiseCrossover(highFileName='data/[email protected]',
                      highDriverName='AN25',
                      lowFileName='data/[email protected]',
                      lowDriverName='TCP115', dataDir='data', plotName='opt',
                      learningRate=1E-2, epochs=25):
    crossover = setupDriverCrossover(highFileName, highDriverName,
                                     lowFileName, lowDriverName)
    flatGrad = jax.grad(crossover.flatness, argnums=[0, 1, 2])
    init_fun, update_fun, get_params = adam(learningRate)

    res = np.log(5E6)
    cap = np.log(2E-12)
    highRes = np.log(1)
    state = init_fun((res, cap, highRes))
    losses = []
    for i in tqdm(range(epochs)):
        grads = flatGrad(res, cap, highRes)
        state = update_fun(i, grads, state)
        res, cap, highRes = get_params(state)
        flatness = crossover.flatness(res, cap, highRes)
        losses.append(flatness)

    availableComponents = AvailableComponents(f'{dataDir}/resistors.json',
                                              f'{dataDir}/capacitors.json')
    nearestRes = availableComponents.nearestRes(np.exp(res))
    nearestHighRes = availableComponents.nearestRes(np.exp(highRes))
    nearestCap = availableComponents.nearestCap(np.exp(cap))
    co, lo, hi = crossover.applyCrossover(nearestRes[1], nearestCap[1],
                                          nearestHighRes[1])

    plt.plot(crossover.frequencies, co, label='Total')
    plt.plot(crossover.frequencies, hi, label='High')
    plt.plot(crossover.frequencies, lo, label='Low')
    plt.xscale('log')
    plt.legend(loc=0, fontsize=18)
    plt.xlabel('Frequency [Hz]', fontsize=16)
    plt.ylabel('Sound pressure level [dB]', fontsize=16)
    plt.savefig(f'{plotName}.pdf')
    plt.clf()
def __init__(self):
    super(Policy, self).__init__()
    self.data = []
    self.net = hk.Sequential([
        hk.Flatten(),
        hk.Linear(128), jax.nn.relu,
        hk.Linear(2), jax.nn.softmax,
    ])
    self.optimizer = adam(step_size=learning_rate)
def get_model_and_optimizer(model_params, hyper_params, rng):
    init_fun, psi, g = serial(model_params)  # load model from serial
    _, init_params = init_fun(
        rng, (hyper_params['batch_size'],
              hyper_params['z_latent'] + 4))  # get initial params
    opt_init, opt_update, get_params_from_opt = adam(
        hyper_params['lr'])  # get optimizer
    opt_state = opt_init(init_params)
    return psi, g, opt_update, opt_state, get_params_from_opt, init_params, opt_init
def optimize_params(self, initial_params, num_iters, step_size, tolerance,
                    verbal):
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
    opt_state = opt_init(initial_params)

    @jit
    def step(i, opt_state):
        p = get_params(opt_state)
        g = grad(self.cost)(p)
        return opt_update(i, g, opt_state)

    cost_list = []
    params_list = []
    if verbal:
        print('{0}\t{1}\t'.format('Iter', 'Cost'))

    for i in range(num_iters):
        opt_state = step(i, opt_state)
        params_list.append(get_params(opt_state))
        cost_list.append(self.cost(params_list[-1]))
        if verbal:
            if i % int(verbal) == 0:
                print('{0}\t{1:.3f}\t'.format(i, cost_list[-1]))
        if len(params_list) > tolerance:
            if np.all(np.array(cost_list[1:]) - np.array(cost_list[:-1]) > 0):
                params = params_list[0]
                if verbal:
                    print('Stop at {} steps: cost has been monotonically '
                          'increasing for {} steps.'.format(i, tolerance))
                break
            elif np.all(np.array(cost_list[:-1]) - np.array(cost_list[1:])
                        < 1e-5):
                params = params_list[-1]
                if verbal:
                    print('Stop at {} steps: cost has been changing less '
                          'than 1e-5 for {} steps.'.format(i, tolerance))
                break
            else:
                params_list.pop(0)
                cost_list.pop(0)
    else:
        params = params_list[-1]
        if verbal:
            print('Stop: reached {} steps, final cost={}.'.format(
                num_iters, cost_list[-1]))
    return params