def main(args):
    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)
    X_train, X_test, mu_true = create_toy_data(toy_data_rng, args.num_samples, args.dimensions)

    train_init, train_fetch = subsample_batchify_data((X_train,), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model, guide, optimizer, ELBO(),
                dp_scale=args.sigma, clipping_threshold=args.clip_threshold,
                d=args.dimensions, num_obs_total=args.num_samples)

    rng, svi_init_rng, batchifier_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=batchifier_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    q = args.batch_size / args.num_samples
    eps = svi.get_epsilon(args.delta, q, num_epochs=args.num_epochs)
    print("Privacy epsilon {} (for sigma: {}, delta: {}, C: {}, q: {})".format(
        eps, args.sigma, args.delta, args.clip_threshold, q))

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state, num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if i % (args.num_epochs // 10) == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state, num_test_batches)
            print("Epoch {}: loss = {} (on training set: {}) ({:.2f} s.)".format(
                i, test_loss, train_loss, t_end - t_start))

    params = svi.get_params(svi_state)
    mu_loc = params['mu_loc']
    mu_std = jnp.exp(params['mu_std_log'])

    print("### expected: {}".format(mu_true))
    print("### svi result\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))

    mu_loc, mu_std = analytical_solution(X_train)
    print("### analytical solution\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))

    mu_loc, mu_std = ml_estimate(X_train)
    print("### ml estimate\nmu_loc: {}\nerror: {}\nmu_std: {}".format(
        mu_loc, jnp.linalg.norm(mu_loc - mu_true), mu_std))
def dz(self, samples, noise_scale=0.4, **args):
    '''Daily deaths with observation noise'''
    dz_mean = self.dz_mean(samples, **args)
    dz = dist.Normal(dz_mean, noise_scale * dz_mean).sample(PRNGKey(10))
    return dz
def on_slider_update(change):
    global lr, momentum
    if change["owner"].description == "log(lr)":
        lr = 10**change["new"]
        # change["owner"].description = "LR : " + str(round(lr, 3))
    elif change["owner"].description == "momentum":
        momentum = change["new"]
    descend_and_update(nn_loss_xy)


X, Y, X_test = get_data(seed=5)
key = PRNGKey(0)
x_tr, x_te = split(X, key)
x_tr = x_tr[:, [1]]
x_te = x_te[:, [1]]
y_tr, y_te = split(Y, key)

n_layers = 3
n_neurons = 3
nn_init_fn, nn_apply_fn = stax.serial(
    *chain(*[(Tanh, Dense(n_neurons)) for _ in range(n_layers)]),
    Dense(1),
)
out_shape, init_params = nn_init_fn(PRNGKey(9), x_tr.shape[1:])
# *** MLP configuration ***
n_hidden = 6
n_in, n_out = 1, 1
n_params = (n_in + 1) * n_hidden + (n_hidden + 1) * n_out
fwd_mlp = partial(mlp, n_hidden=n_hidden)
# vectorised for multiple observations
fwd_mlp_obs = jax.vmap(fwd_mlp, in_axes=[None, 0])
# vectorised for multiple weights
fwd_mlp_weights = jax.vmap(fwd_mlp, in_axes=[1, None])
# vectorised for multiple observations and weights
fwd_mlp_obs_weights = jax.vmap(fwd_mlp_obs, in_axes=[0, None])

# *** Generating training and test data ***
n_obs = 200
key = PRNGKey(314)
key_sample_obs, key_weights = split(key, 2)
xmin, xmax = -3, 3
sigma_y = 3.0
x, y = sample_observations(key_sample_obs, f, n_obs, xmin, xmax, x_noise=0, y_noise=sigma_y)
xtest = jnp.linspace(x.min(), x.max(), n_obs)

# *** MLP Training with EKF ***
W0 = normal(key_weights, (n_params,)) * 1  # initial random guess
Q = jnp.eye(n_params) * 1e-4  # parameters do not change
R = jnp.eye(1) * sigma_y**2   # observation noise is fixed
Vinit = jnp.eye(n_params) * 100  # vague prior

ekf = ds.ExtendedKalmanFilter(fz, fwd_mlp, Q, R)
ekf_mu_hist, ekf_Sigma_hist = ekf.filter(W0, y[:, None], x[:, None], Vinit)
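# A minimal shape-check sketch (not from the original script) illustrating the
# vmap axis conventions above. `_demo_mlp` is an assumed stand-in with the same
# calling convention mlp(w, x, n_hidden) -- a flat weight vector and a scalar
# input; only the vmapped output shapes are the point here.
import jax
import jax.numpy as jnp
from functools import partial

def _demo_mlp(w, x, n_hidden=6):
    # unpack a flat parameter vector into a 1-n_hidden-1 MLP
    W1 = w[:n_hidden].reshape(n_hidden, 1)
    b1 = w[n_hidden:2 * n_hidden]
    W2 = w[2 * n_hidden:3 * n_hidden].reshape(1, n_hidden)
    b2 = w[-1]
    return (W2 @ jnp.tanh(W1 @ jnp.atleast_1d(x) + b1) + b2)[0]

_demo_fwd = partial(_demo_mlp, n_hidden=6)
_demo_fwd_obs = jax.vmap(_demo_fwd, in_axes=[None, 0])             # one weight vector, many inputs
_demo_fwd_obs_weights = jax.vmap(_demo_fwd_obs, in_axes=[0, None])  # many weight vectors as well

_ws = jnp.ones((4, 3 * 6 + 1))       # 4 flat weight vectors
_xs = jnp.linspace(-3, 3, 10)        # 10 scalar inputs
print(_demo_fwd_obs(_ws[0], _xs).shape)       # (10,)
print(_demo_fwd_obs_weights(_ws, _xs).shape)  # (4, 10)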
    def body(v):
        t, rmse, clusters = v
        jac_fn = jacrev(partial(cost_sp, features))
        hes_fn = jacfwd(jac_fn)
        new_cluster = clusters - jac_fn(clusters) / hes_fn(clusters).sum((0, 1))
        rmse = ((new_cluster - clusters)**2).sum()
        return t + 1, rmse, new_cluster

    t, rmse, clusters = while_loop(cond, sparsify(body), (0, float("inf"), clusters))
    return clusters


if __name__ == '__main__':
    data_key, sparse_key = split(PRNGKey(8))
    num_datapoints = 100
    num_features = 7
    sparsity = .5
    num_clusters = 5
    max_iter = 10

    features = normal(data_key, (num_datapoints, num_features))
    features = features * bernoulli(sparse_key, sparsity, (num_datapoints, num_features))
    clusters = features[:num_clusters]
    features = sparse.BCOO.fromdense(features)

    new_cluster = kmeans(max_iter, clusters, features)
    print(new_cluster)
@jit
def accuracy(params, batch):
    logits = predict(params, batch["X"])
    # logits = log_softmax(logits)
    return jnp.mean(jnp.argmax(logits, -1) == batch["y"])


@jit
def logprior(params):
    # Spherical Gaussian prior
    leaves_of_params = tree_leaves(params)
    return sum(tree_map(lambda p: jnp.sum(jax.scipy.stats.norm.logpdf(p, scale=l2_regularizer)),
                        leaves_of_params))


key = PRNGKey(42)
data_key, init_key, opt_key, sample_key, warmstart_key = split(key, 5)

n_train, n_test = 20000, 1000
train_ds, test_ds = load_mnist(data_key, n_train, n_test)
data = (train_ds["X"], train_ds["y"])
n_features = train_ds["X"].shape[1]
n_classes = 10

# model
init_random_params, predict = stax.serial(
    Dense(200), Relu,
    Dense(50), Relu,
    Dense(n_classes), LogSoftmax)

_, params_init_tree = init_random_params(init_key, input_shape=(-1, n_features))
def init_kernel(init_params,
                num_warmup,
                step_size=1.0,
                adapt_step_size=True,
                adapt_mass_matrix=True,
                dense_mass=False,
                target_accept_prob=0.8,
                trajectory_length=2 * math.pi,
                max_tree_depth=10,
                run_warmup=True,
                progbar=True,
                rng=PRNGKey(0)):
    """
    Initializes the HMC sampler.

    :param init_params: Initial parameters to begin sampling. The type must
        be consistent with the input type to `potential_fn`.
    :param int num_warmup: Number of warmup steps; samples generated
        during warmup are discarded.
    :param float step_size: Determines the size of a single step taken by the
        verlet integrator while computing the trajectory using Hamiltonian
        dynamics. If not specified, it will be set to 1.
    :param bool adapt_step_size: A flag to decide if we want to adapt step_size
        during warm-up phase using Dual Averaging scheme.
    :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
        matrix during warm-up phase using Welford scheme.
    :param bool dense_mass: A flag to decide if mass matrix is dense or
        diagonal (default when ``dense_mass=False``)
    :param float target_accept_prob: Target acceptance probability for step size
        adaptation using Dual Averaging. Increasing this value will lead to a
        smaller step size, hence the sampling will be slower but more robust.
        Default to 0.8.
    :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
        value is :math:`2\\pi`.
    :param int max_tree_depth: Max depth of the binary tree created during the
        doubling scheme of NUTS sampler. Defaults to 10.
    :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
        `init_kernel` returns an initial :data:`HMCState` that can be used to
        generate samples using MCMC. Else, returns the arguments and callable
        that does the initial adaptation.
    :param bool progbar: Whether to enable progress bar updates. Defaults to
        ``True``.
    :param jax.random.PRNGKey rng: random key to be used as the source of
        randomness.
    """
    step_size = float(step_size)
    nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth
    trajectory_len = float(trajectory_length)
    max_treedepth = max_tree_depth
    z = init_params
    z_flat, unravel_fn = ravel_pytree(z)
    momentum_generator = partial(_sample_momentum, unravel_fn)

    find_reasonable_ss = partial(find_reasonable_step_size,
                                 potential_fn, kinetic_fn, momentum_generator)

    wa_init, wa_update = warmup_adapter(num_warmup,
                                        adapt_step_size=adapt_step_size,
                                        adapt_mass_matrix=adapt_mass_matrix,
                                        dense_mass=dense_mass,
                                        target_accept_prob=target_accept_prob,
                                        find_reasonable_step_size=find_reasonable_ss)

    rng_hmc, rng_wa = random.split(rng)
    wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat))
    r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
    vv_state = vv_init(z, r)
    hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy,
                         0, 0., 0., wa_state.step_size, wa_state.inverse_mass_matrix,
                         wa_state.mass_matrix_sqrt, rng_hmc)

    wa_update = jit(wa_update)

    if run_warmup:
        # JIT if progress bar updates not required
        if not progbar:
            hmc_state, _ = jit(fori_loop, static_argnums=(2,))(
                0, num_warmup, warmup_update, (hmc_state, wa_state))
        else:
            with tqdm.trange(num_warmup, desc='warmup') as t:
                for i in t:
                    hmc_state, wa_state = warmup_update(i, (hmc_state, wa_state))
                    # TODO: set refresh=True when its performance issue is resolved
                    t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False)
        # Reset `i` and `mean_accept_prob` for fresh diagnostics.
        hmc_state = hmc_state.update(i=0, mean_accept_prob=0)
        return hmc_state
    else:
        return hmc_state, wa_state, warmup_update
def fit(observations, lens, num_hidden, num_obs, batch_size, optimizer,
        rng_key=None, num_epochs=1):
    '''
    Trains the HMM model with the given number of hidden states and observations
    via any optimizer.

    Parameters
    ----------
    observations : array(N, seq_len)
        All observation sequences
    lens : array(N,)
        Consists of the valid length of each observation sequence
    num_hidden : int
        The number of hidden states
    num_obs : int
        The number of observable events
    batch_size : int
        The number of observation sequences that will be included in each minibatch
    optimizer : jax.experimental.optimizers.Optimizer
        Optimizer that is used during training
    rng_key : array
        Random key of shape (2,) and dtype uint32
    num_epochs : int
        The total number of training epochs

    Returns
    -------
    * HMMJax
        Hidden Markov Model
    * array
        Consists of training losses
    '''
    global opt_init, opt_update, get_params

    if rng_key is None:
        rng_key = PRNGKey(0)

    rng_init, rng_iter = split(rng_key)
    params = init_random_params([num_hidden, num_obs], rng_init)
    opt_init, opt_update, get_params = optimizer
    opt_state = opt_init(params)
    itercount = itertools.count()

    def epoch_step(opt_state, key):

        def train_step(opt_state, params):
            batch, length = params
            opt_state, loss = update(next(itercount), opt_state, batch, length)
            return opt_state, loss

        batches, valid_lens = hmm_sample_minibatches(observations, lens, batch_size, key)
        params = (batches, valid_lens)
        opt_state, losses = jax.lax.scan(train_step, opt_state, params)
        return opt_state, losses.mean()

    epochs = split(rng_iter, num_epochs)
    opt_state, losses = jax.lax.scan(epoch_step, opt_state, epochs)
    losses = losses.flatten()

    params = get_params(opt_state)
    params = HMMJax(softmax(params.trans_mat, axis=1),
                    softmax(params.obs_mat, axis=1),
                    softmax(params.init_dist))
    return params, losses
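# A usage sketch (illustrative, not from the original module), assuming the
# module's `HMMJax` pipeline and the old `jax.experimental.optimizers` API the
# docstring refers to. Data shapes and hyperparameters are placeholders.
from jax.experimental import optimizers  # `jax.example_libraries.optimizers` in newer JAX
import jax.numpy as jnp

dummy_obs = jnp.zeros((32, 10), dtype=jnp.int32)  # 32 padded observation sequences of length 10
dummy_lens = jnp.full((32,), 10)                  # every sequence is fully valid here

hmm_fitted, training_losses = fit(dummy_obs, dummy_lens,
                                  num_hidden=5, num_obs=4,
                                  batch_size=8,
                                  optimizer=optimizers.adam(1e-2),
                                  num_epochs=3)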
        try:
            gmm = GMM(pi, mu, Sigma)
            gmm.fit_em(X, num_of_iters=5)
            n_success_ml += 1
        except Exception as E:
            print(str(E))

        try:
            gmm = GMM(pi, mu, Sigma)
            gmm.fit_em(X, num_of_iters=5, S=S, eta=eta)
            n_success_map += 1
        except Exception as E:
            print(str(E))

    pct_ml = n_success_ml / n_attempts
    pct_map = n_success_map / n_attempts
    return [1 - pct_ml, 1 - pct_map]


rng_key = PRNGKey(0)
plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False

n_comps = 3
pi = jnp.ones((n_comps,)) / n_comps

hist_ml, hist_map = [], []
test_dims = jnp.arange(10, 60, 10)
keys = split(rng_key, 10)
n_samples = 150

mu_base = jnp.array([[-1, 1], [1, -1], [3, -1]])
Sigma1_base = jnp.array([[1, -0.7], [-0.7, 1]])
Sigma2_base = jnp.array([[1, 0.7], [0.7, 1]])
S2 = jnp.array([[0.3, -0.5],
                [-0.5, 1.3]])
S3 = jnp.array([[0.8, 0.4],
                [0.4, 0.5]])
cov_collection = jnp.array([S1, S2, S3]) / 60
mu_collection = jnp.array([[0.3, 0.3], [0.8, 0.5], [0.3, 0.8]])

hmm = HMM(
    trans_dist=distrax.Categorical(probs=A),
    init_dist=distrax.Categorical(probs=initial_probs),
    obs_dist=distrax.as_distribution(
        tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(
            loc=mu_collection, covariance_matrix=cov_collection)))

n_samples, seed = 50, 100
samples_state, samples_obs = hmm_sample(hmm, n_samples, PRNGKey(seed))

xmin, xmax = 0, 1
ymin, ymax = 0, 1.2
colors = ["tab:green", "tab:blue", "tab:red"]

fig, ax = plt.subplots()
_, color_sample = plot_2dhmm(hmm, samples_obs, samples_state, colors, ax,
                             xmin, xmax, ymin, ymax)
pml.savefig("hmm_lillypad_2d.pdf")

fig, ax = plt.subplots()
ax.step(range(n_samples), samples_state, where="post", c="black",
def fit_sgd(self, observations, batch_size, rng_key=None, optimizer=None, num_epochs=3):
    '''
    Finds the parameters of the Gaussian Mixture Model using gradient descent
    with the given hyperparameters.

    Parameters
    ----------
    observations : array
        The observation sequences which the Gaussian Mixture Model is trained on
    batch_size : int
        The size of the batch
    rng_key : array
        Random key of shape (2,) and dtype uint32
    optimizer : jax.experimental.optimizers.Optimizer
        Optimizer to be used
    num_epochs : int
        The number of epochs the training process takes

    Returns
    -------
    * array
        Mean loss values found per epoch
    * array
        Mixing coefficients found per epoch
    * array
        Means of the Gaussian distributions found per epoch
    * array
        Covariances of the Gaussian distributions found per epoch
    * array
        Responsibilities found per epoch
    '''
    global opt_init, opt_update, get_params

    if rng_key is None:
        rng_key = PRNGKey(0)
    if optimizer is not None:
        opt_init, opt_update, get_params = optimizer

    opt_state = opt_init((softmax(self.mixing_coeffs), self.means, self.covariances))
    itercount = itertools.count()

    def epoch_step(opt_state, key):

        def train_step(opt_state, batch):
            opt_state, loss = self.update(next(itercount), opt_state, batch)
            return opt_state, loss

        batches = self._make_minibatches(observations, batch_size, key)
        opt_state, losses = scan(train_step, opt_state, batches)

        params = get_params(opt_state)
        mixing_coeffs, means, untransformed_cov = params
        cov_matrix = vmap(self._transform_to_covariance_matrix)(untransformed_cov)
        self.model = (softmax(mixing_coeffs), means, cov_matrix)
        responsibilities = self.responsibilities(observations)

        return opt_state, (losses.mean(), *params, responsibilities)

    epochs = split(rng_key, num_epochs)
    opt_state, history = scan(epoch_step, opt_state, epochs)

    params = get_params(opt_state)
    mixing_coeffs, means, untransformed_cov = params
    cov_matrix = vmap(self._transform_to_covariance_matrix)(untransformed_cov)
    self.model = (softmax(mixing_coeffs), means, cov_matrix)

    return history
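# A hypothetical usage sketch (not from the original class), assuming a `GMM`
# wrapper in this module that exposes `fit_sgd` as a method and is constructed
# from initial mixing coefficients, means, and covariances. Values are illustrative.
import jax.numpy as jnp
from jax.experimental import optimizers
from jax.random import PRNGKey, normal

toy_X = normal(PRNGKey(0), (500, 2))
gmm = GMM(mixing_coeffs=jnp.ones(3) / 3,
          means=normal(PRNGKey(1), (3, 2)),
          covariances=jnp.stack([jnp.eye(2)] * 3))
history = gmm.fit_sgd(toy_X, batch_size=50,
                      rng_key=PRNGKey(2),
                      optimizer=optimizers.adam(1e-2),
                      num_epochs=5)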
def test_pixelcnn():
    loss, _ = PixelCNNPP(nr_filters=1, nr_resnet=1)
    images = jnp.zeros((2, 16, 16, 3), image_dtype)
    opt = optimizers.Adam()
    state = opt.init(loss.init_parameters(images, key=PRNGKey(0)))
def main(args):
    encoder_init, encode = encoder(args.hidden_dim, args.z_dim)
    decoder_init, decode = decoder(args.hidden_dim, 28 * 28)
    opt_init, opt_update = optimizers.adam(args.learning_rate)
    svi_init, svi_update, svi_eval = svi(model, guide, elbo, opt_init, opt_update,
                                         encode=encode, decode=decode, z_dim=args.z_dim)
    svi_update = jit(svi_update)
    rng = PRNGKey(0)

    train_init, train_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='train')
    test_init, test_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='test')

    num_train, train_idx = train_init()
    _, encoder_params = encoder_init((args.batch_size, 28 * 28))
    _, decoder_params = decoder_init((args.batch_size, args.z_dim))
    params = {'encoder': encoder_params, 'decoder': decoder_params}

    rng, sample_batch = binarize(rng, train_fetch(0, train_idx)[0])
    opt_state = svi_init(rng, (sample_batch,), (sample_batch,), params)
    rng, = random.split(rng, 1)

    @jit
    def epoch_train(opt_state, rng):
        def body_fn(i, val):
            loss_sum, opt_state, rng = val
            rng, batch = binarize(rng, train_fetch(i, train_idx)[0])
            loss, opt_state, rng = svi_update(i, opt_state, rng, (batch,), (batch,))
            loss_sum += loss
            return loss_sum, opt_state, rng

        return lax.fori_loop(0, num_train, body_fn, (0., opt_state, rng))

    @jit
    def eval_test(opt_state, rng):
        def body_fun(i, val):
            loss_sum, rng = val
            rng, = random.split(rng, 1)
            rng, batch = binarize(rng, test_fetch(i, test_idx)[0])
            loss = svi_eval(opt_state, rng, (batch,), (batch,)) / len(batch)
            loss_sum += loss
            return loss_sum, rng

        loss, _ = lax.fori_loop(0, num_test, body_fun, (0., rng))
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch):
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)),
                   img, cmap='gray')
        _, test_sample = binarize(rng, img)
        params = optimizers.get_params(opt_state)
        z_mean, z_var = encode(params['encoder'], test_sample.reshape([1, -1]))
        z = dist.norm(z_mean, z_var).rvs(random_state=rng)
        img_loc = decode(params['decoder'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)),
                   img_loc, cmap='gray')

    for i in range(args.num_epochs):
        t_start = time.time()
        num_train, train_idx = train_init()
        _, opt_state, rng = epoch_train(opt_state, rng)
        rng, rng_test = random.split(rng, 2)
        num_test, test_idx = test_init()
        test_loss = eval_test(opt_state, rng_test)
        reconstruct_img(i)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss, time.time() - t_start))
def test_Conv1DTranspose_runs(channels, filter_shape, padding, strides, input_shape):
    conv = Conv1DTranspose(channels, filter_shape, strides=strides, padding=padding)
    inputs = random_inputs(input_shape)
    params = conv.init_parameters(PRNGKey(0), inputs)
    conv.apply(params, inputs)
def main(args):
    N = args.num_samples
    k = args.num_components
    d = args.dimensions

    rng = PRNGKey(1234)
    rng, toy_data_rng = jax.random.split(rng, 2)
    X_train, X_test, latent_vals = create_toy_data(toy_data_rng, N, d)
    train_init, train_fetch = subsample_batchify_data((X_train,), batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data((X_test,), batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    # note(lumip): fix the parameters in the models
    def fix_params(model_fn, k):
        def fixed_params_fn(obs, **kwargs):
            return model_fn(k, obs, **kwargs)
        return fixed_params_fn

    model_fixed = fix_params(model, k)
    guide_fixed = fix_params(guide, k)

    svi = DPSVI(model_fixed, guide_fixed, optimizer, ELBO(),
                dp_scale=0.01, clipping_threshold=20.,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(fetch_rng)
    batch = train_fetch(0, batchifier_state)
    svi_state = svi.init(svi_init_rng, *batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            svi_state, batch_loss = svi.update(svi_state, *batch)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch):
        def body_fn(i, loss_sum):
            batch = test_fetch(i, batchifier_state)
            loss = svi.evaluate(svi_state, *batch)
            loss_sum += loss / (args.num_samples * num_batch)
            return loss_sum

        return lax.fori_loop(0, num_batch, body_fn, 0.)

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state, num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if i % 100 == 0:
            rng, test_fetch_rng = random.split(rng, 2)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss = eval_test(svi_state, test_batchifier_state, num_test_batches)
            print("Epoch {}: loss = {} (on training set = {}) ({:.2f} s.)".format(
                i, test_loss, train_loss, t_end - t_start))

    params = svi.get_params(svi_state)
    print(params)
    posterior_modes = params['mus_loc']
    posterior_pis = dist.Dirichlet(jnp.exp(params['alpha_log'])).mean
    print("MAP estimate of mixture weights: {}".format(posterior_pis))
    print("MAP estimate of mixture modes : {}".format(posterior_modes))

    acc = compute_assignment_accuracy(X_test, latent_vals[1], latent_vals[2],
                                      posterior_modes, posterior_pis)
    print("assignment accuracy: {}".format(acc))
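# Side note (not in the original source): the `fix_params` helper in `main`
# above simply binds the first positional argument of the model function.
# Assuming the signature model(k, obs, **kwargs), the standard-library
# equivalent would be:
#
#   from functools import partial
#   model_fixed = partial(model, k)
#   guide_fixed = partial(guide, k)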
    and the output (label) kets for given `params`
    """
    fidel = 0
    thetas, phis, omegas = params
    unitary = Unitary(N)(thetas, phis, omegas)
    for i in range(train_len):
        pred = jnp.dot(unitary, inputs[i])
        step_fidel = fidelity(pred, outputs[i])
        fidel += step_fidel
    return (fidel / train_len)[0][0]


# Fixed PRNGKeys to pick the same starting params
params = uniform(PRNGKey(0), (N**2,), minval=0.0, maxval=2 * jnp.pi)
thetas = params[:N * (N - 1) // 2]
phis = params[N * (N - 1) // 2:N * (N - 1)]
omegas = params[N * (N - 1):]
params = [thetas, phis, omegas]

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-1)
opt_state = opt_init(params)


def step(i, opt_state, opt_update):
    params = get_params(opt_state)
    g = grad(cost)(params, ket_input, ket_output)
    return opt_update(i, g, opt_state)
sequences = ["HASTA", "VISTA", "ALAVA", "LIMED", "HAST", "HAS", "HASVASTA"] * 5
holdout_sequences = ["HASTA", "VISTA", "ALAVA", "LIMED", "HAST", "HASVALTA"] * 5

PROJECT_NAME = "evotuning_temp"

init_fun, apply_fun = mlstm64()

# The input_shape is always going to be (-1, 26),
# because that is the number of unique AA, one-hot encoded.
_, initial_params = init_fun(PRNGKey(42), input_shape=(-1, 26))

# 1. Evotuning with Optuna
n_epochs_config = {"low": 1, "high": 1}
lr_config = {"low": 1e-5, "high": 1e-3}

study, evotuned_params = evotune(
    sequences=sequences,
    model_func=apply_fun,
    params=initial_params,
    out_dom_seqs=holdout_sequences,
    n_trials=2,
    n_splits=2,
    n_epochs_config=n_epochs_config,
    learning_rate_config=lr_config,
)
def init_kernel(init_params,
                num_warmup,
                step_size=1.0,
                inverse_mass_matrix=None,
                adapt_step_size=True,
                adapt_mass_matrix=True,
                dense_mass=False,
                target_accept_prob=0.8,
                trajectory_length=2 * math.pi,
                max_tree_depth=10,
                rng_key=PRNGKey(0)):
    """
    Initializes the HMC sampler.

    :param init_params: Initial parameters to begin sampling. The type must
        be consistent with the input type to `potential_fn`.
    :param int num_warmup: Number of warmup steps; samples generated
        during warmup are discarded.
    :param float step_size: Determines the size of a single step taken by the
        verlet integrator while computing the trajectory using Hamiltonian
        dynamics. If not specified, it will be set to 1.
    :param numpy.ndarray inverse_mass_matrix: Initial value for inverse mass
        matrix. This may be adapted during warmup if adapt_mass_matrix = True.
        If no value is specified, then it is initialized to the identity matrix.
    :param bool adapt_step_size: A flag to decide if we want to adapt step_size
        during warm-up phase using Dual Averaging scheme.
    :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
        matrix during warm-up phase using Welford scheme.
    :param bool dense_mass: A flag to decide if mass matrix is dense or
        diagonal (default when ``dense_mass=False``)
    :param float target_accept_prob: Target acceptance probability for step size
        adaptation using Dual Averaging. Increasing this value will lead to a
        smaller step size, hence the sampling will be slower but more robust.
        Default to 0.8.
    :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
        value is :math:`2\\pi`.
    :param int max_tree_depth: Max depth of the binary tree created during the
        doubling scheme of NUTS sampler. Defaults to 10.
    :param jax.random.PRNGKey rng_key: random key to be used as the source of
        randomness.
    """
    step_size = lax.convert_element_type(step_size, xla_bridge.canonicalize_dtype(np.float64))
    nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
    wa_steps = num_warmup
    trajectory_len = trajectory_length
    max_treedepth = max_tree_depth
    z = init_params
    z_flat, unravel_fn = ravel_pytree(z)
    momentum_generator = partial(_sample_momentum, unravel_fn)

    find_reasonable_ss = partial(find_reasonable_step_size,
                                 potential_fn, kinetic_fn, momentum_generator)

    wa_init, wa_update = warmup_adapter(num_warmup,
                                        adapt_step_size=adapt_step_size,
                                        adapt_mass_matrix=adapt_mass_matrix,
                                        dense_mass=dense_mass,
                                        target_accept_prob=target_accept_prob,
                                        find_reasonable_step_size=find_reasonable_ss)

    rng_key_hmc, rng_key_wa = random.split(rng_key)
    wa_state = wa_init(z, rng_key_wa, step_size,
                       inverse_mass_matrix=inverse_mass_matrix,
                       mass_matrix_size=np.size(z_flat))
    r = momentum_generator(wa_state.mass_matrix_sqrt, rng_key)
    vv_state = vv_init(z, r)
    energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
    hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy,
                         energy, 0, 0., 0., False, wa_state, rng_key_hmc)
    return hmc_state
def from_seed(cls: Type[T], seed: int) -> T:
    return cls(PRNGKey(seed))
def parameters_from(self, reuse, *example_inputs):
    return self._init_parameters(*example_inputs, key=PRNGKey(0),
                                 reuse=reuse, reuse_only=True)
def get_batches(batches=100, sequence_length=1000, key=PRNGKey(0)):
    for _ in range(batches):
        key, batch_key = random.split(key)
        yield random.normal(batch_key, (1, receptive_field + sequence_length, 1))
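# A short usage sketch (illustrative, not from the original source):
# `receptive_field` is assumed to be a module-level constant defined elsewhere;
# a placeholder value is used here.
receptive_field = 64  # placeholder for the module-level constant

for _batch in get_batches(batches=2, sequence_length=100, key=PRNGKey(1)):
    print(_batch.shape)  # (1, receptive_field + sequence_length, 1) == (1, 164, 1)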
def _example_outputs(self, *inputs):
    _, outputs = self._init_and_apply_parameters_dict(*inputs, key=PRNGKey(0))
    return outputs
def fit_copula_jregression(y, x, n_perm=10, seed=20, n_perm_optim=None, single_bandwidth=True):
    # Set seed for scipy
    np.random.seed(seed)

    # Combine x, y
    z = jnp.concatenate((x, y.reshape(-1, 1)), axis=1)

    # Generate random permutations
    key = PRNGKey(seed)
    key, *subkey = split(key, n_perm + 1)
    subkey = jnp.array(subkey)
    z_perm = vmap(permutation, (0, None))(subkey, z)

    # Initialize parameter and put on correct scale to lie in [0, 1]
    d = jnp.shape(z)[1]
    if single_bandwidth:
        rho_init = 0.9 * jnp.ones(1)
    else:
        rho_init = 0.9 * jnp.ones(d)
    hyperparam_init = jnp.log(1 / rho_init - 1)

    # Calculate rho_opt: either use all permutations or a selected
    # number to fit the bandwidth
    if n_perm_optim is None:
        z_perm_opt = z_perm
    else:
        z_perm_opt = z_perm[0:n_perm_optim]

    # Compiling
    print('Compiling...')
    start = time.time()
    # Conditional prequential log-likelihood and its gradient
    temp = mvcr.fun_jcll_perm_sp(hyperparam_init, z_perm_opt)
    temp = mvcr.grad_jcll_perm_sp(hyperparam_init, z_perm_opt)
    temp = mvcd.update_pn_loop_perm(rho_init, z_perm)[0].block_until_ready()
    end = time.time()
    print('Compilation time: {}s'.format(round(end - start, 3)))

    print('Optimizing...')
    start = time.time()
    # Optimize the conditional prequential log-likelihood
    opt = minimize(fun=mvcr.fun_jcll_perm_sp, x0=hyperparam_init,
                   args=(z_perm_opt,), jac=mvcr.grad_jcll_perm_sp, method='SLSQP')
    # Check that the optimization succeeded
    if not opt.success:
        print('Optimization failed')

    # Unscale hyperparameter
    hyperparam_opt = opt.x
    rho_opt = 1 / (1 + jnp.exp(hyperparam_opt))
    end = time.time()
    print('Optimization time: {}s'.format(round(end - start, 3)))

    print('Fitting...')
    start = time.time()
    vn_perm = mvcd.update_pn_loop_perm(rho_opt, z_perm)[0].block_until_ready()
    end = time.time()
    print('Fit time: {}s'.format(round(end - start, 3)))

    copula_jregression_obj = namedtuple('copula_jregression_obj',
                                        ['vn_perm', 'rho_opt', 'preq_loglik'])
    return copula_jregression_obj(vn_perm, rho_opt, -opt.fun)
def main(args):
    encoder_nn = encoder(args.hidden_dim, args.z_dim)
    decoder_nn = decoder(args.hidden_dim, 28 * 28)
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, adam, Trace_ELBO(),
              hidden_dim=args.hidden_dim, z_dim=args.z_dim)
    rng_key = PRNGKey(0)

    train_init, train_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='train')
    test_init, test_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='test')

    num_train, train_idx = train_init()
    rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
    sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
    svi_state = svi.init(rng_key_init, sample_batch)

    @jit
    def epoch_train(svi_state, rng_key, train_idx):
        def body_fn(i, val):
            loss_sum, svi_state = val
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
            svi_state, loss = svi.update(svi_state, batch)
            loss_sum += loss
            return loss_sum, svi_state

        return lax.fori_loop(0, num_train, body_fn, (0., svi_state))

    @jit
    def eval_test(svi_state, rng_key, test_idx):
        def body_fun(i, loss_sum):
            rng_key_binarize = random.fold_in(rng_key, i)
            batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
            # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
            loss = svi.evaluate(svi_state, batch) / len(batch)
            loss_sum += loss
            return loss_sum

        loss = lax.fori_loop(0, num_test, body_fun, 0.)
        loss = loss / num_test
        return loss

    def reconstruct_img(epoch, rng_key):
        img = test_fetch(0, test_idx)[0][0]
        plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)),
                   img, cmap='gray')
        rng_key_binarize, rng_key_sample = random.split(rng_key)
        test_sample = binarize(rng_key_binarize, img)
        params = svi.get_params(svi_state)
        z_mean, z_var = encoder_nn[1](params['encoder$params'], test_sample.reshape([1, -1]))
        z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
        img_loc = decoder_nn[1](params['decoder$params'], z).reshape([28, 28])
        plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)),
                   img_loc, cmap='gray')

    for i in range(args.num_epochs):
        rng_key, rng_key_train = random.split(rng_key)
        t_start = time.time()
        num_train, train_idx = train_init()
        _, svi_state = epoch_train(svi_state, rng_key_train, train_idx)
        rng_key, rng_key_test, rng_key_reconstruct = random.split(rng_key, 3)
        num_test, test_idx = test_init()
        test_loss = eval_test(svi_state, rng_key_test, test_idx)
        reconstruct_img(i, rng_key_reconstruct)
        print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss, time.time() - t_start))
def main(args):
    rng = PRNGKey(123)
    rng, toy_data_rng = jax.random.split(rng)

    train_data, test_data, true_params = create_toy_data(
        toy_data_rng, args.num_samples, args.dimensions)

    train_init, train_fetch = subsample_batchify_data(train_data, batch_size=args.batch_size)
    test_init, test_fetch = split_batchify_data(test_data, batch_size=args.batch_size)

    ## Init optimizer and training algorithms
    optimizer = optimizers.Adam(args.learning_rate)

    svi = DPSVI(model, guide, optimizer, ELBO(),
                dp_scale=0.01, clipping_threshold=20.,
                num_obs_total=args.num_samples)

    rng, svi_init_rng, data_fetch_rng = random.split(rng, 3)
    _, batchifier_state = train_init(rng_key=data_fetch_rng)
    sample_batch = train_fetch(0, batchifier_state)

    svi_state = svi.init(svi_init_rng, *sample_batch)

    @jit
    def epoch_train(svi_state, batchifier_state, num_batch):
        def body_fn(i, val):
            svi_state, loss = val
            batch = train_fetch(i, batchifier_state)
            batch_X, batch_Y = batch
            svi_state, batch_loss = svi.update(svi_state, batch_X, batch_Y)
            loss += batch_loss / (args.num_samples * num_batch)
            return svi_state, loss

        return lax.fori_loop(0, num_batch, body_fn, (svi_state, 0.))

    @jit
    def eval_test(svi_state, batchifier_state, num_batch, rng):
        params = svi.get_params(svi_state)

        def body_fn(i, val):
            loss_sum, acc_sum = val
            batch = test_fetch(i, batchifier_state)
            batch_X, batch_Y = batch
            loss = svi.evaluate(svi_state, batch_X, batch_Y)
            loss_sum += loss / (args.num_samples * num_batch)

            acc_rng = jax.random.fold_in(rng, i)
            acc = estimate_accuracy(batch_X, batch_Y, params, acc_rng, 1)
            acc_sum += acc / num_batch

            return loss_sum, acc_sum

        return lax.fori_loop(0, num_batch, body_fn, (0., 0.))

    ## Train model
    for i in range(args.num_epochs):
        t_start = time.time()
        rng, data_fetch_rng = random.split(rng, 2)
        num_train_batches, train_batchifier_state = train_init(rng_key=data_fetch_rng)
        svi_state, train_loss = epoch_train(svi_state, train_batchifier_state, num_train_batches)
        train_loss.block_until_ready()
        t_end = time.time()

        if (i % (args.num_epochs // 10)) == 0:
            rng, test_rng, test_fetch_rng = random.split(rng, 3)
            num_test_batches, test_batchifier_state = test_init(rng_key=test_fetch_rng)
            test_loss, test_acc = eval_test(svi_state, test_batchifier_state,
                                            num_test_batches, test_rng)
            print("Epoch {}: loss = {}, acc = {} (loss on training set: {}) ({:.2f} s.)"
                  .format(i, test_loss, test_acc, train_loss, t_end - t_start))

    # parameters for logistic regression may be scaled arbitrarily. normalize
    # w (and scale intercept accordingly) for comparison
    w_true = normalize(true_params[0])
    scale_true = jnp.linalg.norm(true_params[0])
    intercept_true = true_params[1] / scale_true

    params = svi.get_params(svi_state)
    w_post = normalize(params['w_loc'])
    scale_post = jnp.linalg.norm(params['w_loc'])
    intercept_post = params['intercept_loc'] / scale_post

    print("w_loc: {}\nexpected: {}\nerror: {}".format(
        w_post, w_true, jnp.linalg.norm(w_post - w_true)))
    print("w_std: {}".format(jnp.exp(params['w_std_log'])))
    print("")
    print("intercept_loc: {}\nexpected: {}\nerror: {}".format(
        intercept_post, intercept_true, jnp.abs(intercept_post - intercept_true)))
    print("intercept_std: {}".format(jnp.exp(params['intercept_std_log'])))
    print("")

    X_test, y_test = test_data
    rng, rng_acc_true, rng_acc_post = jax.random.split(rng, 3)
    # for evaluating accuracy with the true parameters, we scale them to the
    # same scale as the found posterior (this gives better results than
    # normalized parameters, probably due to numerical instabilities)
    acc_true = estimate_accuracy_fixed_params(X_test, y_test, w_true, intercept_true,
                                              rng_acc_true, 10)
    acc_post = estimate_accuracy(X_test, y_test, params, rng_acc_post, 10)

    print("avg accuracy on test set: with true parameters: {} ; with found posterior: {}\n"
          .format(acc_true, acc_post))
def init_kernel(init_params,
                num_warmup,
                step_size=1.0,
                adapt_step_size=True,
                adapt_mass_matrix=True,
                dense_mass=False,
                target_accept_prob=0.8,
                trajectory_length=2 * math.pi,
                max_tree_depth=10,
                run_warmup=True,
                progbar=True,
                rng_key=PRNGKey(0)):
    """
    Initializes the HMC sampler.

    :param init_params: Initial parameters to begin sampling. The type must
        be consistent with the input type to `potential_fn`.
    :param int num_warmup: Number of warmup steps; samples generated
        during warmup are discarded.
    :param float step_size: Determines the size of a single step taken by the
        verlet integrator while computing the trajectory using Hamiltonian
        dynamics. If not specified, it will be set to 1.
    :param bool adapt_step_size: A flag to decide if we want to adapt step_size
        during warm-up phase using Dual Averaging scheme.
    :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
        matrix during warm-up phase using Welford scheme.
    :param bool dense_mass: A flag to decide if mass matrix is dense or
        diagonal (default when ``dense_mass=False``)
    :param float target_accept_prob: Target acceptance probability for step size
        adaptation using Dual Averaging. Increasing this value will lead to a
        smaller step size, hence the sampling will be slower but more robust.
        Default to 0.8.
    :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
        value is :math:`2\\pi`.
    :param int max_tree_depth: Max depth of the binary tree created during the
        doubling scheme of NUTS sampler. Defaults to 10.
    :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
        `init_kernel` returns an initial :data:`~numpyro.infer.mcmc.HMCState`
        that can be used to generate samples using MCMC. Else, returns the
        arguments and callable that does the initial adaptation.
    :param bool progbar: Whether to enable progress bar updates. Defaults to
        ``True``.
    :param jax.random.PRNGKey rng_key: random key to be used as the source of
        randomness.
    """
    step_size = lax.convert_element_type(step_size, xla_bridge.canonicalize_dtype(np.float64))
    nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth, wa_steps
    wa_steps = num_warmup
    trajectory_len = trajectory_length
    max_treedepth = max_tree_depth
    z = init_params
    z_flat, unravel_fn = ravel_pytree(z)
    momentum_generator = partial(_sample_momentum, unravel_fn)

    find_reasonable_ss = partial(find_reasonable_step_size,
                                 potential_fn, kinetic_fn, momentum_generator)

    wa_init, wa_update = warmup_adapter(num_warmup,
                                        adapt_step_size=adapt_step_size,
                                        adapt_mass_matrix=adapt_mass_matrix,
                                        dense_mass=dense_mass,
                                        target_accept_prob=target_accept_prob,
                                        find_reasonable_step_size=find_reasonable_ss)

    rng_key_hmc, rng_key_wa = random.split(rng_key)
    wa_state = wa_init(z, rng_key_wa, step_size, mass_matrix_size=np.size(z_flat))
    r = momentum_generator(wa_state.mass_matrix_sqrt, rng_key)
    vv_state = vv_init(z, r)
    energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
    hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy,
                         energy, 0, 0., 0., False, wa_state, rng_key_hmc)

    # TODO: Remove; this should be the responsibility of the MCMC class.
    if run_warmup and num_warmup > 0:
        # JIT if progress bar updates not required
        if not progbar:
            hmc_state = fori_loop(0, num_warmup,
                                  lambda *args: sample_kernel(args[1]), hmc_state)
        else:
            with tqdm.trange(num_warmup, desc='warmup') as t:
                for i in t:
                    hmc_state = jit(sample_kernel)(hmc_state)
                    t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False)
    return hmc_state
def dy(self, samples, noise_scale=0.4, **args):
    '''Daily confirmed cases with observation noise'''
    dy_mean = self.dy_mean(samples, **args)
    dy = dist.Normal(dy_mean, noise_scale * dy_mean).sample(PRNGKey(11))
    return dy
in terms of the speed. It also checks whether or not the inference algorithms
give the same result.

Author : Aleyna Kara (@karalleyna)
'''

import time

import jax.numpy as jnp
from jax.random import PRNGKey, split, uniform

import numpy as np

from hmm_lib_log import HMM, hmm_forwards_backwards_log, hmm_viterbi_log, hmm_sample_log
import distrax

seed = 0
rng_key = PRNGKey(seed)
rng_key, key_A, key_B = split(rng_key, 3)

# state transition matrix (rows normalized to sum to one)
n_hidden, n_obs = 100, 10
A = uniform(key_A, (n_hidden, n_hidden))
A = A / jnp.sum(A, axis=1).reshape((-1, 1))

# observation matrix
B = uniform(key_B, (n_hidden, n_obs))
B = B / jnp.sum(B, axis=1).reshape((-1, 1))

n_samples = 1000
init_state_dist = jnp.ones(n_hidden) / n_hidden

seed = 0
from jax.random import PRNGKey
import jax.numpy as np

from numpyro import sample
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

from mixture_model import NormalMixture
from NMC import NMC

# Model key
rng = PRNGKey(0)


# Mixture model
def mix_model(data):
    w = sample("w", dist.Dirichlet((1 / 3) * np.ones(3)), rng_key=rng)
    mu = sample("mu", dist.Normal(np.zeros(3), np.ones(3)), rng_key=rng)
    std = sample("std", dist.Gamma(np.ones(3), np.ones(3)), rng_key=rng)
    sample("obs", NormalMixture(w, mu, std), rng_key=rng, obs=data)


# Data for mixture model
data_test1 = sample("norm1", dist.Normal(10, 1), rng_key=PRNGKey(0), sample_shape=(1000,))
data_test2 = sample("norm2", dist.Normal(0, 1), rng_key=PRNGKey(0), sample_shape=(1000,))
def __init__(self, cost, backend='torch'):
    self.normalized = cost.normalized
    self.constant_goal = cost.constant_goal
    self.c = cost
    self._key = PRNGKey(0)