def test_multiple_observed_rv(self):
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    y1 = np.random.randn(10)
    y2 = np.random.randn(100)

    def model_example_multiple_obs(y1=None, y2=None):
        x = numpyro.sample("x", dist.Normal(1, 3))
        numpyro.sample("y1", dist.Normal(x, 1), obs=y1)
        numpyro.sample("y2", dist.Normal(x, 1), obs=y2)

    nuts_kernel = NUTS(model_example_multiple_obs)
    mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2)
    mcmc.run(PRNGKey(0), y1=y1, y2=y2)
    inference_data = from_numpyro(mcmc)
    test_dict = {
        "posterior": ["x"],
        "sample_stats": ["diverging"],
        "log_likelihood": ["y1", "y2"],
        "observed_data": ["y1", "y2"],
    }
    fails = check_multiple_attrs(test_dict, inference_data)
    # from ..stats import waic
    # waic_results = waic(inference_data)
    # print(waic_results)
    # print(waic_results.keys())
    # print(waic_results.waic, waic_results.waic_se)
    assert not fails
    assert not hasattr(inference_data.sample_stats, "log_likelihood")
def conditional_from_guide(self, guide, params, *args, **kwargs):
    pred_noise, diag = kwargs.pop("pred_noise", False), kwargs.pop("diag", False)
    self._get_var_names(*args, **kwargs)
    predictive = Predictive(
        self.model,
        guide=guide,
        params=params,
        num_samples=self.num_samples,
        return_sites=(
            self.gp,
            self.mean,
            self.cond,
            self.Kss,
            self.Kns,
            self.Ksx,
            self.Kxx,
            self.Knx,
            self.y,
        ),
    )
    self.cond_params = predictive(PRNGKey(self.rng_key), *args)
    mu, var = self._build_conditional(self.cond_params, pred_noise, diag)
    return mu, var
def test_threshold_preschedule(self):
    threshold_schedule = jnp.linspace(10, 0.1, 10)
    sampler = MetropolisedABCSMCSampler(threshold_schedule=threshold_schedule)
    sample = run(self.scenario, sampler, n=self.n, random_key=PRNGKey(0))
    self._test_mean(sample.value[-1])
    self._test_cov(sample.value[-1])
def __init__(
    self,
    config: PretrainedConfig,
    module: nn.Module,
    input_shape: Tuple = (1, 1),
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
):
    if config is None:
        raise ValueError("config cannot be None")

    if module is None:
        raise ValueError("module cannot be None")

    # These are private; derived classes expose them as typed properties.
    self._config = config
    self._module = module

    # These are public; their type is generic to every derived class.
    self.key = PRNGKey(seed)
    self.dtype = dtype

    # Randomly initialized parameters
    random_params = self.init_weights(self.key, input_shape)

    # Save the required parameter names as a set
    self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
    self.params = random_params
def rand_ket(N, seed=None):
    r"""Returns a random :math:`N`-dimensional ket.

    Args:
        N (int): Dimension of random ket

    Returns:
        :obj:`jnp.ndarray`: random :math:`N \times 1` dimensional vector (ket)
    """
    if seed is None:
        seed = np.random.randint(1000)
    ket = uniform(PRNGKey(seed), (N, 1)) + 1j * uniform(PRNGKey(seed), (N, 1))
    return ket / jnp.linalg.norm(ket)
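# Hedged usage sketch for `rand_ket` above (not part of the original source): it draws a
# 4-dimensional ket with a fixed seed and checks that the result is unit-normalized.
# The helper name `_example_rand_ket_usage` is hypothetical; `jnp` is assumed to be
# jax.numpy, as in the surrounding code.
def _example_rand_ket_usage():
    psi = rand_ket(4, seed=123)                      # shape (4, 1), complex entries
    assert psi.shape == (4, 1)
    assert jnp.isclose(jnp.linalg.norm(psi), 1.0)    # kets come back with unit norm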
def rand_unitary(N, seed=None):
    r"""Returns an :math:`N \times N` randomly parametrized unitary.

    Args:
        N (int): Size of the Hilbert space

    Returns:
        :obj:`jnp.ndarray`: :math:`N \times N` parametrized random unitary matrix

    .. note::
        JAX provides Pseudo-Random Number Generator Keys (PRNG keys) that aim to
        ensure reproducibility. The `seed` integer here is fed to a `PRNGKey`,
        which returns an array of shape (2,) for every input integer seed.
        A PRNGKey built from the same input integer samples the same values
        from any distribution.
    """
    if seed is None:
        seed = np.random.randint(1000)
    params = uniform(PRNGKey(seed), (N ** 2,), minval=0.0, maxval=2 * jnp.pi)
    rand_thetas = params[: N * (N - 1) // 2]
    rand_phis = params[N * (N - 1) // 2 : N * (N - 1)]
    rand_omegas = params[N * (N - 1) :]
    return Unitary(N)(rand_thetas, rand_phis, rand_omegas)
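# A minimal sketch (not from the original source) illustrating the PRNG-key note above:
# the same integer seed yields the same key array of shape (2,), and hence identical
# samples from `uniform`. The helper name `_example_prng_key_reproducibility` is
# hypothetical; `PRNGKey` and `uniform` are assumed to come from jax.random, as in
# the surrounding code.
def _example_prng_key_reproducibility():
    key_a = PRNGKey(42)
    key_b = PRNGKey(42)
    assert key_a.shape == (2,)                       # a raw PRNG key is two uint32 words
    assert (uniform(key_a, (3,)) == uniform(key_b, (3,))).all()  # same key, same samples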
def test_Parameter(Parameter=Parameter):
    scalar = Parameter(lambda _: np.zeros(()))
    params = scalar.init_parameters(PRNGKey(0))
    assert np.zeros(()) == params

    out = scalar.apply(params)
    assert params == out
def mnist_data(n_obs, rng_key=None):
    '''
    Downloads MNIST data via tf.keras.datasets and returns a random binarized subset.

    Parameters
    ----------
    n_obs : int
        Number of digits randomly chosen from MNIST

    rng_key : array
        Random key of shape (2,) and dtype uint32

    Returns
    -------
    * array((n_obs, 784))
        Dataset
    '''
    rng_key = PRNGKey(0) if rng_key is None else rng_key

    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    x = (x_train > 0).astype('int')  # converting to a binary dataset
    dataset_size = x.shape[0]

    perm = randint(rng_key, minval=0, maxval=dataset_size, shape=(n_obs,))
    x_train = x[perm]
    x_train = x_train.reshape((n_obs, 784))

    return x_train
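# Hedged usage sketch for `mnist_data` above (not part of the original source): samples
# 500 binarized MNIST digits with an explicit key. The helper name is hypothetical, and
# TensorFlow is assumed to be installed since the loader goes through tf.keras.datasets.
def _example_mnist_data_usage():
    x = mnist_data(500, rng_key=PRNGKey(42))
    assert x.shape == (500, 784)                     # flattened 28x28 digits, entries in {0, 1}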
def test_adaptive_diag_stepsize(self):
    sampler = RandomWalkABC()
    sampler.tuning.target = 0.1
    sample = run(self.scenario, sampler, n=self.n, random_key=PRNGKey(0),
                 correction=RMMetropolisDiagStepsize(rm_stepsize_scale=0.1))
    self._test_mean(sample.value)
    npt.assert_almost_equal(sample.alpha.mean(), sampler.tuning.target, decimal=1)
def main():
    step_size = 0.001
    num_epochs = 100
    batch_size = 32
    test_key = PRNGKey(1)  # get reconstructions for a *fixed* latent variable sample over time

    train_images, test_images = mnist_images()
    num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
    num_batches = num_complete_batches + bool(leftover)
    opt = optimizers.Momentum(step_size, mass=0.9)

    @jit
    def binarize_batch(key, i, images):
        i = i % num_batches
        batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size)
        return random.bernoulli(key, batch)

    @jit
    def run_epoch(key, state):
        def body_fun(i, state):
            loss_key, data_key = random.split(random.fold_in(key, i))
            batch = binarize_batch(data_key, i, train_images)
            return opt.update(loss.apply, state, batch, key=loss_key)

        return lax.fori_loop(0, num_batches, body_fun, state)

    example_key = PRNGKey(0)
    example_batch = binarize_batch(example_key, 0, images=train_images)
    shaped_elbo = loss.shaped(example_batch)
    init_parameters = shaped_elbo.init_parameters(key=PRNGKey(2))
    state = opt.init(init_parameters)

    for epoch in range(num_epochs):
        tic = time.time()
        state = run_epoch(PRNGKey(epoch), state)
        params = opt.get_parameters(state)
        test_elbo, samples = evaluate.apply_from({shaped_elbo: params}, test_images,
                                                 key=test_key, jit=True)
        print(f'Epoch {epoch: 3d} {test_elbo:.3f} ({time.time() - tic:.3f} sec)')

        from matplotlib import pyplot as plt
        plt.imshow(samples, cmap=plt.cm.gray)
        plt.show()
def forecast(self, num_samples=1000, rng_key=PRNGKey(4), **args):
    if self.mcmc_samples is None:
        raise RuntimeError("run inference first")
    predictive = Predictive(self, posterior_samples=self.mcmc_samples)
    args = dict(self.args, **args)
    return predictive(rng_key, **self.obs, **args)
def test_adaptive(self):
    retain_parameter = 0.8
    sampler = MetropolisedABCSMCSampler(ess_threshold_retain=retain_parameter)
    sample = run(self.scenario, sampler, n=self.n, random_key=PRNGKey(0))
    self._test_mean(sample.value[-1])
    self._test_cov(sample.value[-1])
def prior_sample(self, X=None, num_samps=1, key=0):
    if X is None:
        X = self.X
    N = X.shape[0]
    m = np.zeros(N)
    K = self.kernel(X, X) + 1e-12 * np.eye(N)
    s = multivariate_normal(PRNGKey(key), m, K, shape=[num_samps])
    return s.T
def test_Dropout_shape(mode, input_shape=(1, 2, 3)):
    dropout = Dropout(.9, mode=mode)
    inputs = np.zeros(input_shape)
    out = dropout(inputs, PRNGKey(0))
    assert np.array_equal(np.zeros(input_shape), out)

    out_ = dropout(inputs, rng=PRNGKey(0))
    assert np.array_equal(out, out_)

    try:
        dropout(inputs)
        assert False
    except ValueError as e:
        assert 'dropout requires to be called with a PRNG key argument. ' \
               'That is, instead of `dropout(params, inputs)`, ' \
               'call it like `dropout(inputs, key)` ' \
               'where `key` is a jax.random.PRNGKey value.' == str(e)
def prior(self, num_samples=1000, rng_key=PRNGKey(2), **args):
    predictive = Predictive(self, posterior_samples={}, num_samples=num_samples)

    args = dict(self.args, **args)  # passed args take precedence
    self.prior_samples = predictive(rng_key, **args)

    return self.prior_samples
def test_post_threshold(self):
    acceptance_rate = 0.1
    sampler = VanillaABC(acceptance_rate=acceptance_rate)
    sample = run(self.scenario, sampler, n=self.n, random_key=PRNGKey(0))
    self._test_mean(sample.value[sample.log_weight > -jnp.inf], 10.)
    npt.assert_almost_equal(jnp.mean(sample.log_weight == 0), acceptance_rate, decimal=3)
    self.assertNotEqual(sampler.parameters.threshold, jnp.inf)
def svi_predict(model, guide, params, args, X):
    predictive = Predictive(model=model, guide=guide, params=params,
                            num_samples=args.num_samples)
    predictions = predictive(PRNGKey(1), X=X, Y=None)
    svi_predictions = jnp.rint(predictions["Y"].mean(0))
    return svi_predictions
def init_emission_probs(self, mixing_coeffs, probs, dataset, targets, rng_key=None, num_of_iter=7):
    """
    Initializes the emission distribution as a class-conditional Bernoulli mixture model.

    Parameters
    ----------
    mixing_coeffs : array
        The probabilities of mixture_distribution of ClassConditionalBMM
    probs : array
        The probabilities of components_distribution of ClassConditionalBMM
    dataset : array
        Dataset
    targets : array
        The ground-truth labels of the dataset
    rng_key : array
        Random key of shape (2,) and dtype uint32
    num_of_iter : int
        The number of iterations of the training process
    """
    class_priors = np.zeros((self.word_len, self.n_char))  # observation likelihoods

    for i in range(self.word_len):
        class_priors[i] = self.emission_prob_(self.word[i])

    if (mixing_coeffs is None or probs is None) and (dataset is not None and targets is not None):
        mixing_coeffs = jnp.full((self.n_char - 1, self.n_mix), 1. / self.n_mix)

        if rng_key is None:
            rng_key = PRNGKey(0)

        probs = uniform(rng_key, minval=0.4, maxval=0.6,
                        shape=(self.n_char - 1, self.n_mix, dataset.shape[-1]))
        class_conditional_bmm = ClassConditionalBMM(mixing_coeffs, probs,
                                                    jnp.array(class_priors), self.n_char - 1)
        class_conditional_bmm.fit_em(dataset, targets, num_of_iter)
        self._obs_dist = class_conditional_bmm
    else:
        self._obs_dist = ClassConditionalBMM(mixing_coeffs, probs,
                                             jnp.array(class_priors), self.n_char - 1)
def hmm_em_jax(observations, valid_lengths, n_hidden=None, n_obs=None,
               init_params=None, priors=None, num_epochs=1, rng_key=None):
    '''
    Implements the Baum–Welch algorithm, which is used to find the HMM components A, B and pi.

    Parameters
    ----------
    observations : array
        All observation sequences
    valid_lengths : array
        Valid lengths of each observation sequence
    n_hidden : int
        The number of hidden states
    n_obs : int
        The number of observable events
    init_params : HMMJax
        Initial Hidden Markov Model
    priors : PriorsJax
        Priors for the components of the Hidden Markov Model
    num_epochs : int
        Number of training epochs (EM iterations)
    rng_key : array
        Random key of shape (2,) and dtype uint32

    Returns
    -------
    * HMMJax
        Trained Hidden Markov Model
    * array
        Negative log likelihoods, each of which can be interpreted as the loss value
        at that iteration.
    '''
    if rng_key is None:
        rng_key = PRNGKey(0)

    if init_params is None:
        try:
            init_params = init_random_params_jax([n_hidden, n_obs], rng_key=rng_key)
        except:
            raise ValueError("n_hidden and n_obs should be specified when init_params is not given.")

    epochs = jnp.arange(num_epochs)

    def train_step(params, epoch):
        trans_counts, obs_counts, init_counts, ll = hmm_e_step_jax(params, observations, valid_lengths)
        params = hmm_m_step_jax([trans_counts, obs_counts, init_counts], priors)
        return params, -ll

    final_params, neg_loglikelihoods = jax.lax.scan(train_step, init_params, epochs)

    return final_params, neg_loglikelihoods
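# Hedged usage sketch for `hmm_em_jax` above (not part of the original source). The helper
# name is hypothetical, and it assumes `observations` is an integer array of padded
# observation sequences and `valid_lengths` holds each sequence's true length, as the
# docstring describes. The HMM is initialized randomly from `rng_key` since `init_params`
# is omitted.
def _example_hmm_em_jax_usage(observations, valid_lengths):
    params, neg_lls = hmm_em_jax(observations, valid_lengths,
                                 n_hidden=3, n_obs=5,
                                 num_epochs=20, rng_key=PRNGKey(0))
    # neg_lls[i] is the negative log likelihood after iteration i; under EM it
    # typically decreases as training proceeds.
    return params, neg_lls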
def test_save_and_load_params():
    params = Dense(2).init_parameters(np.zeros((1, 2)), key=PRNGKey(0))

    from pathlib import Path
    path = Path('/') / 'tmp' / 'net.params'
    save(params, path)
    params_ = load(path)

    assert_dense_parameters_equal(params, params_)
def test_mnist_vae():
    @parametrized
    def encode(input):
        input = Sequential(Dense(5), relu, Dense(5), relu)(input)
        mean = Dense(10)(input)
        variance = Sequential(Dense(10), softplus)(input)
        return mean, variance

    decode = Sequential(Dense(5), relu, Dense(5), relu, Dense(5 * 5))

    @parametrized
    def elbo(key, images):
        mu_z, sigmasq_z = encode(images)
        logits_x = decode(gaussian_sample(key, mu_z, sigmasq_z))
        return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

    params = elbo.init_parameters(PRNGKey(0), np.zeros((32, 5 * 5)), key=PRNGKey(0))
    assert (5, 10) == params.encode.sequential1.dense.kernel.shape
def test_Parameter_with_multiple_arrays(Parameter=Parameter):
    two_scalars = Parameter(lambda _: (np.zeros(()), np.zeros(())))
    params = two_scalars.init_parameters(key=PRNGKey(0))
    a, b = params
    assert np.zeros(()) == a
    assert np.zeros(()) == b

    out = two_scalars.apply(params)
    assert params == out
def test_dict_input():
    @parametrized
    def net(input_dict):
        return input_dict['a'] * input_dict['b'] * parameter((), zeros)

    inputs = {'a': np.zeros(2), 'b': np.zeros(2)}
    params = net.init_parameters(inputs, key=PRNGKey(0))
    out = net.apply(params, inputs)
    assert np.array_equal(np.zeros(2), out)
def test_tuple_input():
    @parametrized
    def net(input_dict):
        return input_dict[0] * input_dict[1] * parameter((), zeros, input_dict[0])

    inputs = (np.zeros((2,)), np.zeros((2,)))
    params = net.init_parameters(PRNGKey(0), inputs)
    out = net.apply(params, inputs)
    assert np.array_equal(np.zeros((2,)), out)
def parameters_from(self, reuse, *example_inputs):
    expanded_reuse = parametrized._expand_reuse_dict(reuse, *example_inputs)

    # TODO: optimization wrong, duplicate values, needs param adapter
    return self.init_parameters(PRNGKey(0), *example_inputs,
                                reuse=expanded_reuse, reuse_only=True)
def test_mlstm1900():
    """Test forward pass of pre-built mlstm1900 model."""
    init_fun, model_fun = mlstm1900()
    _, params = init_fun(PRNGKey(42), input_shape=(-1, 26))

    oh = seq_to_oh("HASTA")
    out = model_fun(params, oh)
    assert out.shape == (7, 25)
def test_Conv1DTranspose_runs(channels, filter_shape, padding, strides, input_shape):
    convt = Conv1DTranspose(channels, filter_shape, strides=strides, padding=padding)
    inputs = random_inputs(input_shape)
    params = convt.init_parameters(PRNGKey(0), inputs)
    convt.apply(params, inputs)
def main(batch_size=256, env_name="CartPole-v1"):
    env = gym.make(env_name)
    policy = Sequential(Dense(64), relu, Dense(env.action_space.n))

    @parametrized
    def loss(observations, actions, rewards_to_go):
        logprobs = log_softmax(policy(observations))
        action_logprobs = logprobs[np.arange(logprobs.shape[0]), actions]
        return -np.mean(action_logprobs * rewards_to_go, axis=0)

    opt = Adam()
    shaped_loss = loss.shaped(np.zeros((1,) + env.observation_space.shape),
                              np.array([0]), np.array([0]))

    @jit
    def sample_action(state, key, observation):
        loss_params = opt.get_parameters(state)
        logits = policy.apply_from({shaped_loss: loss_params}, observation)
        return sample_categorical(key, logits)

    rng_init, rng = random.split(PRNGKey(0))
    state = opt.init(shaped_loss.init_parameters(key=rng_init))

    returns, observations, actions, rewards_to_go = [], [], [], []

    for i in range(250):
        while len(observations) < batch_size:
            observation = env.reset()
            episode_done = False
            rewards = []

            while not episode_done:
                rng_step, rng = random.split(rng)
                action = sample_action(state, rng_step, observation)
                observations.append(observation)
                actions.append(action)
                observation, reward, episode_done, info = env.step(int(action))
                rewards.append(reward)

            returns.append(onp.sum(rewards))
            rewards_to_go += list(onp.flip(onp.cumsum(onp.flip(rewards))))

        print(f'Batch {i}, recent mean return: {onp.mean(returns[-100:]):.1f}')

        state = opt.update(loss.apply, state,
                           np.array(observations[:batch_size]),
                           np.array(actions[:batch_size]),
                           np.array(rewards_to_go[:batch_size]),
                           jit=True)

        observations = observations[batch_size:]
        actions = actions[batch_size:]
        rewards_to_go = rewards_to_go[batch_size:]
def generate_params(N, key=PRNGKey(0)):
    """Yields parameterizing angles for `make_unitary`."""
    for _ in range(3):
        key, subkey = split(key)
        thetas = uniform(subkey, (N * (N - 1) // 2,), minval=0.0, maxval=2 * jnp.pi)
        phis = uniform(subkey, (N * (N - 1) // 2,), minval=0.0, maxval=2 * jnp.pi)
        omegas = uniform(subkey, (N,), minval=0.0, maxval=2 * jnp.pi)
        yield thetas, phis, omegas
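# Hedged usage sketch for `generate_params` above (not part of the original source): it
# iterates over the three yielded angle sets for a 4-dimensional unitary and checks their
# shapes (N(N-1)/2 thetas and phis, N omegas). The helper name is hypothetical.
def _example_generate_params_usage():
    for thetas, phis, omegas in generate_params(4, key=PRNGKey(1)):
        assert thetas.shape == (6,)
        assert phis.shape == (6,)
        assert omegas.shape == (4,)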
def test_regularized_submodule():
    net = Sequential(Conv(2, (1, 1)), relu, Conv(2, (1, 1)), relu, flatten,
                     L2Regularized(Sequential(Dense(2), relu, Dense(2), np.sum), .1))

    input = np.ones((1, 3, 3, 1))
    params = net.init_parameters(input, key=PRNGKey(0))
    assert (2, 2) == params.regularized.model.dense1.kernel.shape

    out = net.apply(params, input)
    assert () == out.shape