def loss_fun(params, rng, data, batch_size=None, n=None, loss_type="nlp", reduce="sum"):
    """
    :param batch_size: How large a batch to subselect from the provided data
    :param n: The total size of the dataset (to multiply batch estimate by)
    """
    assert loss_type in ("nlp", "mse")
    inputs, targets = data
    n = inputs.shape[0] if n is None else n
    if batch_size is not None:
        rng, rng_batch = random.split(rng)
        i = random.permutation(rng_batch, n)[:batch_size]
        inputs, targets = inputs[i], targets[i]
    preds = apply_fun(params, rng, inputs).squeeze()
    mean_loss = (
        -norm.logpdf(targets.squeeze(), preds, params["noise"]).mean()
        if loss_type == "nlp"
        else np.power(targets.squeeze() - preds, 2).mean())
    if reduce == "sum":
        loss = n * mean_loss
    elif reduce == "mean":
        loss = mean_loss
    return loss
def generate_nested_circles(key, n_samples, inner_radius=2, outer_radius=4, noise=0.15):
    k1, k2, k3, k4 = random.split(key, 4)

    # Generate the circles
    inner_t = random.uniform(k1, shape=(n_samples // 2,)) * 2 * jnp.pi
    inner_circle = inner_radius * jnp.vstack([jnp.cos(inner_t), jnp.sin(inner_t)])

    outer_t = random.uniform(k2, shape=(n_samples // 2,)) * 2 * jnp.pi
    outer_circle = outer_radius * jnp.vstack([jnp.cos(outer_t), jnp.sin(outer_t)])

    data = jnp.vstack([inner_circle.T, outer_circle.T])

    # Keep track of the labels
    y = jnp.hstack([jnp.zeros(n_samples // 2), jnp.ones(n_samples // 2)])

    # Shuffle the data
    idx = jnp.arange(n_samples)
    idx = random.permutation(k3, idx)
    data = data[idx]
    y = y[idx]

    data += random.normal(k4, data.shape) * noise
    return data, y
def gibbs_fn(rng_key, gibbs_sites, hmc_sites, pe):
    # get support_sizes of gibbs_sites
    support_sizes_flat, _ = ravel_pytree({k: support_sizes[k] for k in gibbs_sites})
    num_discretes = support_sizes_flat.shape[0]

    # use the dedicated subkey for the permutation so rng_key is not reused below
    rng_key, rng_permute = random.split(rng_key)
    idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

    def body_fn(i, val):
        idx = idxs[i]
        support_size = support_sizes_flat[idx]
        rng_key, z, pe = val
        rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
            rng_key, z, pe,
            potential_fn=partial(potential_fn, z_hmc=hmc_sites),
            idx=idx, support_size=support_size)
        rng_key, rng_accept = random.split(rng_key)
        # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
        # and -log(u) ~ Exponential(1)
        z, pe = cond(random.exponential(rng_accept) > -log_accept_ratio,
                     (z_new, pe_new), identity,
                     (z, pe), identity)
        return rng_key, z, pe

    init_val = (rng_key, gibbs_sites, pe)
    _, gibbs_sites, pe = fori_loop(0, num_discretes, body_fn, init_val)
    return gibbs_sites, pe
def test_permutation_invariance(self):
    num_nodes = 4
    num_features = 2
    rng = random.PRNGKey(0)

    # Generate random graph.
    adjacency = random.randint(rng, (num_nodes, num_nodes), 0, 2)
    node_feats = random.normal(rng, (num_nodes, num_features))
    sources, targets = jnp.where(adjacency)

    # Get permuted graph.
    perm = random.permutation(rng, jnp.arange(num_nodes))
    node_feats_perm = node_feats[perm]
    adjacency_perm = adjacency[perm]
    for j in range(len(adjacency)):
        adjacency_perm = jax.ops.index_update(
            adjacency_perm, j, adjacency_perm[j][perm])
    sources_perm, targets_perm = jnp.where(adjacency_perm)

    # Create GNN.
    _, initial_params = GNN.init(
        rng, node_x=node_feats, edge_x=None, sources=sources, targets=targets)
    model = nn.Model(GNN, initial_params)

    # Feedforward both original and permuted graph.
    logits = model(node_feats, None, sources, targets)
    logits_perm = model(node_feats_perm, None, sources_perm, targets_perm)

    self.assertAllClose(logits[perm], logits_perm, check_dtypes=False)
def loss(params, key):
    keys = random.split(key, 5)
    indices = random.permutation(keys[0], jnp.arange(X.shape[0]))[:batch_size]
    X_batch = X[indices, :]
    wind_velocity = random.uniform(keys[1], shape=(3,),
                                   minval=jnp.asarray([-200., -200., 0.]),
                                   maxval=jnp.asarray([200., 200., 0.])) / 1000.
    bottom = random.uniform(keys[2], minval=50., maxval=500.)
    width = random.uniform(keys[3], minval=40., maxval=300.)
    l = random.uniform(keys[4], minval=1., maxval=30.)
    sigma = 1.
    K = kernel(X_batch, X_batch, bottom, width, l, sigma,
               wind_velocity=wind_velocity)
    neural_kernel.set_params(params)
    neural_K = neural_kernel(X_batch, X_batch, bottom, width, l, sigma,
                             wind_velocity=wind_velocity)
    return jnp.mean((K - neural_K) ** 2) / width ** 2
def _make_minibatches(self, observations, batch_size, rng_key):
    '''
    Creates minibatches consisting of random permutations of the given
    observation sequences.

    Parameters
    ----------
    observations : array(N, seq_len)
        Dataset

    batch_size : int
        The number of observation sequences included in each minibatch

    rng_key : array
        Random key of shape (2,) and dtype uint32

    Returns
    -------
    * array(num_batches, batch_size, max_len)
        Minibatches
    '''
    num_train = len(observations)
    perm = permutation(rng_key, num_train)

    def create_mini_batch(batch_idx):
        return observations[batch_idx]

    num_batches = num_train // batch_size
    # drop any incomplete final batch so the reshape below is always valid
    batch_indices = perm[:num_batches * batch_size].reshape((num_batches, -1))
    minibatches = vmap(create_mini_batch)(batch_indices)
    return minibatches
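A hypothetical usage sketch for the minibatching helper above; `model` (an instance of the surrounding class) and `observations` of shape (N, seq_len) are assumed names, not part of the original source.

# Hypothetical usage sketch (assumed names: model, observations).
key = random.PRNGKey(0)
minibatches = model._make_minibatches(observations, batch_size=32, rng_key=key)
for batch in minibatches:  # each batch has shape (batch_size, seq_len)
    ...  # e.g. one training update per minibatch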
def fit(self, X):
    opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
    opt_state = opt_init((self.encoder_params, self.decoder_params))

    def loss(params, inputs):
        encoder_params, decoder_params = params
        # Encode the minibatch and reconstruct it from the encoding.
        enc = self.encoder_apply(encoder_params, inputs)
        dec = self.decoder_apply(decoder_params, enc)
        # Reconstruction error plus a small L1 penalty over all parameter leaves.
        l1 = sum(np.abs(leaf).sum() for leaf in jax.tree_util.tree_leaves(params))
        return np.square(inputs - dec).sum() + 1e-3 * l1

    @jit
    def step(i, opt_state, inputs):
        params = get_params(opt_state)
        gradient = grad(loss)(params, inputs)
        return opt_update(i, gradient, opt_state)

    print('Training autoencoder...')
    batch_size, itercount = 32, itertools.count()
    key = random.PRNGKey(0)
    for epoch in range(5):
        temp_key, key = random.split(key)
        # Shuffle the rows of X before each epoch.
        X = random.permutation(temp_key, X)
        for batch_index in range(0, X.shape[0], batch_size):
            opt_state = step(next(itercount), opt_state,
                             X[batch_index:batch_index + batch_size])

    self.encoder_params, self.decoder_params = get_params(opt_state)
def loss(rng: jnp.ndarray, bij_params: Sequence[jnp.ndarray],
         bij_fns: Sequence[Callable], deq_params: Sequence[jnp.ndarray],
         deq_fn: Callable, xon: jnp.ndarray) -> float:
    """Loss function composed of the evidence lower bound and score matching
    loss.

    Args:
        rng: Pseudo-random number generator seed.
        bij_params: List of arrays parameterizing the RealNVP bijectors.
        bij_fns: List of functions that compute the shift and scale of the
            RealNVP affine transformation.
        deq_params: Parameters of the mean and scale functions used in the
            log-normal dequantizer.
        deq_fn: Function that computes the mean and scale of the dequantization
            distribution.
        xon: Observations on O(n).

    Returns:
        nelbo: The negative evidence lower bound.
    """
    rng, rng_loss, rng_idx = random.split(rng, 3)
    idx = random.permutation(rng_idx, len(xon))[:100]
    xobs = xon[idx]
    if args.elbo_loss:
        nelbo = negative_elbo(rng_loss, bij_params, bij_fns, deq_params,
                              deq_fn, xobs)
        nelbo = nelbo.mean()
        return nelbo
    else:
        log_is = importance_log_density(rng_loss, bij_params, bij_fns,
                                        deq_params, deq_fn,
                                        args.num_importance, xobs)
        log_target = log_density(xobs)
        return jnp.mean(log_target - log_is)
def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
    """Train for a single epoch."""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = random.permutation(rng, len(train_ds['image']))
    perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm] for k, v in train_ds.items()}
        optimizer, metrics = train_step(optimizer, batch)
        batch_metrics.append(metrics)

    # compute mean of metrics across each batch in epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: onp.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]
    }

    logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                 epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100)

    return optimizer, epoch_metrics_np
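A minimal driver-loop sketch for train_epoch, assuming `optimizer`, `train_ds`, `batch_size`, and `num_epochs` are already defined elsewhere; a fresh subkey is split per epoch so each epoch uses a different permutation of the dataset.

# Hypothetical driver loop (not part of the original snippet).
rng = random.PRNGKey(0)
for epoch in range(1, num_epochs + 1):
    rng, epoch_rng = random.split(rng)  # new shuffling key each epoch
    optimizer, metrics = train_epoch(optimizer, train_ds, batch_size, epoch, epoch_rng)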
def data_stream(rng, batch_size, X_train, y_train):
    num_batches, leftover = divmod(X_train.shape[0], batch_size)
    while True:
        temp, rng = random.split(rng)
        perm = random.permutation(temp, X_train.shape[0])
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield X_train[batch_idx], y_train[batch_idx]
def data_stream(rng, batch_size, X, y):
    num_complete_batches, leftover = divmod(X.shape[0], batch_size)
    num_batches = num_complete_batches + bool(leftover)
    while True:
        temp, rng = random.split(rng)
        perm = random.permutation(temp, X.shape[0])
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield X[batch_idx], y[batch_idx]
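A short usage sketch for the infinite generator above; `X`, `y`, `params`, and the `update` step function are assumed placeholders rather than names from the original source.

# Hypothetical usage (assumed names: X, y, params, update).
rng = random.PRNGKey(0)
batches = data_stream(rng, batch_size=128, X=X, y=y)
for step in range(1000):
    X_batch, y_batch = next(batches)  # one shuffled minibatch per step
    params = update(params, X_batch, y_batch)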
def _make_iaf_args(input_dim, hidden_dims):
    _, rng_perm = random.split(random.PRNGKey(0))
    perm = random.permutation(rng_perm, np.arange(input_dim))
    # we use Elu nonlinearity because the default one, Relu, masks out negative
    # hidden values, which in turn create some zero entries in the lower
    # triangular part of the Jacobian.
    arn_init, arn = AutoregressiveNN(input_dim, hidden_dims, param_dims=[1, 1],
                                     permutation=perm, nonlinearity=stax.Elu)
    _, init_params = arn_init(random.PRNGKey(0), (input_dim,))
    return partial(arn, init_params),
def data_stream(key, X, y, batch_size):
    n_data = len(X)
    while True:
        perm_key, key = split(key)
        perm = permutation(perm_key, n_data)
        num_batches, mod = divmod(n_data, batch_size)
        num_batches += 1 if mod else 0
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:min((i + 1) * batch_size, n_data)]
            yield X[batch_idx], y[batch_idx]
def __iter__(self) -> Iterator[Batch]:
    starts = np.arange(0, len(self.inputs), self.batch_size)
    self.key, subkey = random.split(self.key)
    starts = random.permutation(subkey, starts)
    for start in starts:
        end = start + self.batch_size
        batch_inputs = self.inputs[start:end]
        batch_targets = self.targets[start:end]
        yield Batch(batch_inputs, batch_targets)
def data_stream(self, train_images, train_labels, num_train, num_batches, batch_size):
    """Returns batches of data for training"""
    key = random.PRNGKey(0)
    while True:
        key, subkey = random.split(key)
        perm = random.permutation(subkey, num_train)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield train_images[batch_idx], train_labels[batch_idx]
def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    # convert to unconstrained values
    z_hmc = {
        k: biject_to(prototype_trace[k]["fn"].support).inv(v)
        for k, v in hmc_sites.items()
        if k in prototype_trace and prototype_trace[k]["type"] == "sample"
    }
    use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
    wrapped_model = _wrap_model(model)
    if use_enum:
        from numpyro.contrib.funsor import config_enumerate, enum
        wrapped_model = enum(config_enumerate(wrapped_model),
                             -max_plate_nesting - 1)

    def potential_fn(z_discrete):
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_discrete
        return potential_energy(wrapped_model, model_args, model_kwargs_,
                                z_hmc, enum=use_enum)

    # get support_sizes of gibbs_sites
    support_sizes_flat, _ = ravel_pytree(
        {k: support_sizes[k] for k in gibbs_sites})
    num_discretes = support_sizes_flat.shape[0]

    # use the dedicated subkey for the permutation so rng_key is not reused below
    rng_key, rng_permute = random.split(rng_key)
    idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

    def body_fn(i, val):
        idx = idxs[i]
        support_size = support_sizes_flat[idx]
        rng_key, z, pe = val
        rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
            rng_key, z, pe, potential_fn=potential_fn, idx=idx,
            support_size=support_size)
        rng_key, rng_accept = random.split(rng_key)
        # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
        # and -log(u) ~ Exponential(1)
        z, pe = cond(
            random.exponential(rng_accept) > -log_accept_ratio,
            (z_new, pe_new), identity, (z, pe), identity)
        return rng_key, z, pe

    init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
    _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
    return gibbs_sites
def train_one_epoch(params, X_train, y_train, epoch, r):
    num_samples = X_train.shape[0]
    random_sample_idx = random.permutation(random.PRNGKey(epoch),
                                           jnp.arange(num_samples))
    for idx in tqdm(range(0, num_samples, BATCH_SIZE)):
        mini_batch_idx = random_sample_idx[idx:idx + BATCH_SIZE]
        mini_batch_x = X_train[mini_batch_idx]
        mini_batch_y = y_train[mini_batch_idx]
        params_grad = grad(mlp_loss, argnums=0)(params, mini_batch_x, mini_batch_y)
        params, r = apply_grads(params, params_grad, r, LEARNING_RATE)
    loss = mlp_loss(params, mini_batch_x, mini_batch_y)
    return loss, params, r
def sample_observations(key, f, n_obs, xmin, xmax, x_noise=0.1, y_noise=3.0):
    key_x, key_y, key_shuffle = split(key, 3)
    x_noise = normal(key_x, (n_obs,)) * x_noise
    y_noise = normal(key_y, (n_obs,)) * y_noise
    x = jnp.linspace(xmin, xmax, n_obs) + x_noise
    y = f(x) + y_noise
    X = np.c_[x, y]

    shuffled_ixs = permutation(key_shuffle, jnp.arange(n_obs))
    x, y = jnp.array(X[shuffled_ixs, :].T)
    return x, y
def init_fun(rng, input_dim, **kwargs):
    perm = random.permutation(rng, np.arange(input_dim))
    inv_perm = np.argsort(perm)

    def direct_fun(params, inputs, **kwargs):
        return inputs[:, perm], np.zeros(inputs.shape[:1])

    def inverse_fun(params, inputs, **kwargs):
        return inputs[:, inv_perm], np.zeros(inputs.shape[:1])

    return (), direct_fun, inverse_fun
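A small round-trip check for the permutation layer above, assuming `np` is jax.numpy and `random` is jax.random as in the snippet; applying the inverse permutation restores the original inputs and the log-determinant is zero.

# Hypothetical round-trip check (not from the original source).
rng = random.PRNGKey(0)
params, direct_fun, inverse_fun = init_fun(rng, input_dim=4)
x = np.arange(12.0).reshape(3, 4)    # batch of 3 four-dimensional vectors
z, log_det = direct_fun(params, x)   # permute the feature columns
x_rec, _ = inverse_fun(params, z)    # undo the permutation
assert np.allclose(x_rec, x) and np.all(log_det == 0.0)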
def testPermutationInteger(self):
    key = random.PRNGKey(0)
    x = 100
    rand = lambda key: random.permutation(key, x)
    crand = api.jit(rand)

    perm1 = rand(key)
    perm2 = crand(key)

    self.assertAllClose(perm1, perm2)
    self.assertEqual(perm1.dtype, perm2.dtype)
    self.assertFalse(np.all(perm1 == np.arange(100)))  # seems unlikely!
    self.assertAllClose(np.sort(perm1), np.arange(100), check_dtypes=False)
def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    z_hmc = hmc_sites
    use_enum = len(set(support_sizes) - set(gibbs_sites)) > 0
    if use_enum:
        from numpyro.contrib.funsor import config_enumerate, enum
        wrapped_model_ = enum(config_enumerate(wrapped_model),
                              -max_plate_nesting - 1)
    else:
        wrapped_model_ = wrapped_model

    def potential_fn(z_discrete):
        model_kwargs_ = model_kwargs.copy()
        model_kwargs_["_gibbs_sites"] = z_discrete
        return potential_energy(wrapped_model_, model_args, model_kwargs_,
                                z_hmc, enum=use_enum)

    # get support_sizes of gibbs_sites
    support_sizes_flat, _ = ravel_pytree(
        {k: support_sizes[k] for k in gibbs_sites})
    num_discretes = support_sizes_flat.shape[0]

    # use the dedicated subkey for the permutation so rng_key is not reused below
    rng_key, rng_permute = random.split(rng_key)
    idxs = random.permutation(rng_permute, jnp.arange(num_discretes))

    def body_fn(i, val):
        idx = idxs[i]
        support_size = support_sizes_flat[idx]
        rng_key, z, pe = val
        rng_key, z_new, pe_new, log_accept_ratio = proposal_fn(
            rng_key, z, pe, potential_fn=potential_fn, idx=idx,
            support_size=support_size)
        rng_key, rng_accept = random.split(rng_key)
        # u ~ Uniform(0, 1), u < accept_ratio => -log(u) > -log_accept_ratio
        # and -log(u) ~ Exponential(1)
        z, pe = cond(
            random.exponential(rng_accept) > -log_accept_ratio,
            (z_new, pe_new), identity, (z, pe), identity)
        return rng_key, z, pe

    init_val = (rng_key, gibbs_sites, potential_fn(gibbs_sites))
    _, gibbs_sites, _ = fori_loop(0, num_discretes, body_fn, init_val)
    return gibbs_sites
def testPermutationArray(self, dtype, shape):
    key = random.PRNGKey(0)
    x = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)
    rand = lambda key: random.permutation(key, x)
    crand = api.jit(rand)

    perm1 = rand(key)
    perm2 = crand(key)

    self.assertAllClose(perm1, perm2)
    self.assertFalse(np.all(perm1 == x))  # seems unlikely!
    self.assertAllClose(np.sort(perm1.ravel()), x.ravel(), check_dtypes=False)
    self.assertArraysAllClose(
        x, jnp.arange(np.prod(shape)).reshape(shape).astype(dtype))
def get_batch(sampling, key, X, minibatch_size, iteration):
    if sampling == 'batch':
        # Calculate epoch from iteration
        epoch = iteration // (X.shape[0] // minibatch_size)
        batch_index = iteration % (X.shape[0] // minibatch_size)
        batch_index_start = batch_index * minibatch_size
        # Regular batching: reshuffle at the start of each epoch
        if batch_index == 0:
            temp_key, key = random.split(key)
            X = random.permutation(temp_key, X)
        return X[batch_index_start:batch_index_start + minibatch_size], X
    elif sampling == 'uniform':
        # Uniform subsampling
        temp_key, key = random.split(key)
        X = random.permutation(temp_key, X)
        return X[:minibatch_size], X
    elif sampling == 'poisson':
        # Poisson subsampling
        temp_key, key = random.split(key)
        whether = random.uniform(temp_key, (X.shape[0],)) < (minibatch_size / X.shape[0])
        return X[whether], X
    else:
        raise Exception('Invalid sampling method: {}'.format(sampling))
def epoch_step(opt_state, key):
    perm = permutation(key, len(observations))
    _observations, _targets = observations[perm], targets[perm]
    sample_generator = self._sample_minibatches(
        (_observations, _targets), batch_size)

    def train_step(opt_state, i):
        opt_state, loss = self.update(next(itercount), opt_state,
                                      next(sample_generator))
        return opt_state, loss

    opt_state, losses = scan(train_step, opt_state, jnp.arange(num_batches))
    return opt_state, losses.mean()
def load_mnist(key, n_train, n_test, shuffle=True):
    (X, y), (X_test, y_test) = mnist.load_data()

    # Clamp the requested sizes to the available data.
    n_train = n_train if n_train < len(y) else len(y)
    n_test = n_test if n_test < len(y_test) else len(y_test)

    train_key, test_key = split(key)

    train_indices = jnp.arange(len(y))
    perm = permutation(train_key, train_indices)[:n_train] if shuffle else train_indices[:n_train]
    train_ds = {
        "X": jnp.float32(X[perm].reshape(n_train, -1)) / 255.,
        "y": jnp.array(y[perm])
    }

    test_indices = jnp.arange(len(y_test))
    perm = permutation(test_key, test_indices)[:n_test] if shuffle else test_indices[:n_test]
    test_ds = {
        "X": jnp.float32(X_test[perm].reshape(n_test, -1)) / 255.,
        "y": jnp.array(y_test[perm])
    }

    return train_ds, test_ds
def testPermutationArray(self, dtype):
    key = random.PRNGKey(0)
    x = onp.arange(100).astype(dtype)
    rand = lambda key: random.permutation(key, x)
    crand = api.jit(rand)

    perm1 = rand(key)
    perm2 = crand(key)

    self.assertAllClose(perm1, perm2, check_dtypes=True)
    self.assertEqual(perm1.dtype, perm2.dtype)
    self.assertFalse(onp.all(perm1 == x))  # seems unlikely!
    self.assertAllClose(onp.sort(perm1), x, check_dtypes=False)
    self.assertArraysAllClose(x, onp.arange(100).astype(dtype), check_dtypes=True)
def train_test_split(data, rng=None, n_test=None):
    """ Create a train-test split """
    rng = PRNGKey(42) if rng is None else rng
    n = len(data.x)
    rng, rng_perm = split(rng)
    i = permutation(rng_perm, n)
    n_test = min(16384, int(0.1 * n)) if n_test is None else n_test
    if isinstance(n_test, float):
        n_test = int(n_test * n)
    n_train = n - n_test
    i_train, i_test = i[:n_train], i[n_train:]
    return (
        Data(data.x[i_train], data.y[i_train]),
        Data(data.x[i_test], data.y[i_test]),
    )
def next_mask(self, prev_sel, size, rng):
    # Choose the degrees of the next layer
    max_connection = self.dim - 1 if not self.triangular_jacobian else self.dim

    if self.method == "random":
        sel = random.randint(rng, shape=(size,),
                             minval=min(jnp.min(prev_sel), max_connection),
                             maxval=self.dim)
    elif "sequential" in self.method:
        sel = jnp.arange(size) % max(1, max_connection) + min(1, max_connection)
        if self.method == "shuffled_sequential":
            sel = random.permutation(rng, sel)
    else:
        assert 0, "Invalid mask method"

    # Create the new mask
    mask = (prev_sel[:, None] <= sel).astype(jnp.int32)

    return mask, sel
def data_preprocessing():
    """ Separates data (spin configurations) into test and training set and
    generates labels"""
    rng = random.PRNGKey(0)
    temperatures = jnp.linspace(1.0, 4.0, 7)
    temperatures1 = [1.0, 1.5, 3.0, 3.5, 4.0]
    temperatures2 = [2.0, 2.5]
    x_train = []
    y_train = []
    x_test = []
    y_test = []
    for T in temperatures:
        configs = jnp.load('data/spins_T%s.npy' % T)
        magnetization_density = jnp.abs(
            jnp.array([jnp.sum(config) / config.size for config in configs]))
        labels = jnp.where(magnetization_density < 0.5, 0, 1)
        if T in temperatures2:
            # Temperatures near the critical point go entirely into the test set.
            x_test.append(configs)
            y_test.append(labels)
        else:
            # Otherwise shuffle and split 20% test / 80% train.
            indices = random.permutation(rng, labels.size)
            y_test.append(labels[indices[:int(0.2 * labels.size)]])
            y_train.append(labels[indices[int(0.2 * labels.size):]])
            x_test.append(configs[indices[:int(0.2 * labels.size)]])
            x_train.append(configs[indices[int(0.2 * labels.size):]])

    y_test_new = jnp.array(y_test[0])
    x_test_new = jnp.array(x_test[0])
    for i in range(len(y_test) - 1):
        y_test_new = jnp.concatenate((y_test_new, y_test[i + 1]))
        x_test_new = jnp.concatenate((x_test_new, x_test[i + 1]))

    L = jnp.array(x_train).shape[2]
    x_test = jnp.array(x_test_new).reshape((-1, L, L, 1)).astype(jnp.float64)
    y_test = jnp.array(y_test_new).reshape((-1, 1))
    x_train = jnp.array(x_train).reshape((-1, L, L, 1)).astype(jnp.float64)
    y_train = jnp.array(y_train).reshape((-1, 1))

    jnp.save('data/x_test.npy', x_test)
    jnp.save('data/y_test.npy', y_test)
    jnp.save('data/x_train.npy', x_train)
    jnp.save('data/y_train.npy', y_train)

    return x_train, y_train, x_test, y_test
def init_fun(rng, input_dim, **kwargs):
    perm = random.permutation(rng, np.arange(input_dim))
    inv_perm = np.argsort(perm)

    @ForwardFunction
    def forward_fun(params, inputs, **kwargs):
        outputs = inputs[:, perm]
        log_det = np.zeros(inputs.shape[0])
        return outputs, log_det

    @InverseFunction
    def inverse_fun(params, inputs, **kwargs):
        outputs = inputs[:, inv_perm]
        log_det = np.zeros(inputs.shape[0])
        return outputs, log_det

    return (), forward_fun, inverse_fun