def apply_fun(params, inputs):
    conv_params, pair_params, conv_block_params, serial_params = params

    # Apply the primary convolutional layer.
    conv_out = conv_apply(conv_params, inputs)
    conv_out = relu(conv_out)

    # Group all possible pairs: the same kernel W is applied with kernel
    # (rhs) dilations 1..5, pairing features at increasing offsets.
    W, b = pair_params
    pair_1 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1, 1), (1, 1), dim_nums) + b
    pair_2 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1, 1), (1, 2), dim_nums) + b
    pair_3 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1, 1), (1, 3), dim_nums) + b
    pair_4 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1, 1), (1, 4), dim_nums) + b
    pair_5 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1, 1), (1, 5), dim_nums) + b
    pair_out = jnp.dstack([pair_1, pair_2, pair_3, pair_4, pair_5])
    pair_out = relu(pair_out)

    # Convolutional block.
    conv_block_out = conv_block_apply(conv_block_params, pair_out)

    # Residual connection.
    res_out = conv_block_out + pair_out
    res_out = relu(res_out)

    # Forward pass through the remaining serial layers.
    out = serial_apply(serial_params, res_out)
    return out
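# A minimal sketch (not from the source) of the globals apply_fun assumes.
# All shapes and layouts here are hypothetical; only the argument order of
# lax.conv_general_dilated (strides, padding, lhs_dilation, rhs_dilation,
# dimension_numbers) is taken from the call sites above.
from jax import lax
import jax.numpy as jnp

unit_stride = (1, 1)
zero_pad = "SAME"  # keeps spatial dims equal across the five dilations
dim_nums = lax.conv_dimension_numbers(
    (1, 8, 16, 16),               # hypothetical NCHW input shape
    (8, 8, 1, 2),                 # hypothetical OIHW kernel shape (width 2)
    ("NCHW", "OIHW", "NCHW"))
relu = lambda x: jnp.maximum(x, 0.0)
conv_general_dilated = lax.conv_general_dilated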
def rgb_to_hsv(rgb_image):
    # Adapted from the NumPy implementation at:
    # https://gist.github.com/PolarNick239/691387158ff1c41ad73c#file-rgb_to_hsv_np-py
    input_shape = rgb_image.shape
    rgb_image = rgb_image.reshape(-1, 3)
    r, g, b = rgb_image[:, 0], rgb_image[:, 1], rgb_image[:, 2]

    maxc = jnp.maximum(jnp.maximum(r, g), b)
    minc = jnp.minimum(jnp.minimum(r, g), b)

    v = maxc
    deltac = maxc - minc
    # The epsilon keeps the division finite when maxc == 0 (pure black).
    s = deltac / (maxc + 1e-9)
    deltac = jnp.where(deltac == 0, 1, deltac)
    # The epsilon is still needed below even though deltac is nonzero by now:
    # jnp.where evaluates (and differentiates) both branches, so without it
    # NaNs can leak in under autodiff.
    rc = (maxc - r) / (deltac + 1e-9)
    gc = (maxc - g) / (deltac + 1e-9)
    bc = (maxc - b) / (deltac + 1e-9)

    # The nested jnp.where calls are redundant in the forward pass (the outer
    # condition already selects the branch); they appear aimed at the same
    # NaN issue.
    h = 4.0 + gc - rc
    h = jnp.where(g == maxc,
                  2.0 + jnp.where(g == maxc, rc, 0) - jnp.where(g == maxc, bc, 0),
                  h)
    h = jnp.where(r == maxc,
                  jnp.where(r == maxc, bc, 0) - jnp.where(r == maxc, gc, 0),
                  h)
    h = jnp.where(minc == maxc, 0.0, h)
    h = (h / 6.0) % 1.0

    res = jnp.dstack([h, s, v])
    return res.reshape(input_shape)
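# Usage sketch for rgb_to_hsv: pure red maps to (h, s, v) ≈ (0, 1, 1) and
# pure blue to (2/3, 1, 1). The output keeps the input's shape.
import jax.numpy as jnp

image = jnp.array([[[1.0, 0.0, 0.0],    # red
                    [0.0, 0.0, 1.0]]])  # blue
hsv = rgb_to_hsv(image)
print(hsv.shape)  # (1, 2, 3)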
def sample(rng, params, num_samples=1):
    # Draw num_samples points from every mixture component, then select one
    # component per sample according to the categorical mixture weights.
    # Note: `means`, `covariances` and `weights` are free variables here,
    # and `params` is unused.
    cluster_samples = []
    for mean, cov in zip(means, covariances):
        rng, temp_rng = random.split(rng)
        cluster_sample = random.multivariate_normal(temp_rng, mean, cov,
                                                    (num_samples,))
        cluster_samples.append(cluster_sample)
    samples = np.dstack(cluster_samples)  # (num_samples, dim, n_components)
    idx = random.categorical(rng, weights, shape=(num_samples, 1, 1))
    return np.squeeze(np.take_along_axis(samples, idx, -1))
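# Usage sketch with hypothetical module-level mixture parameters (the
# function reads means/covariances/weights from the enclosing scope, and
# random.categorical expects log-probabilities).
from jax import random
import jax.numpy as np

means = [np.zeros(2), 4.0 * np.ones(2)]
covariances = [np.eye(2), 0.5 * np.eye(2)]
weights = np.log(np.array([0.3, 0.7]))

draws = sample(random.PRNGKey(0), params=None, num_samples=5)  # (5, 2)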
def torsionVecs_(self, P):
    p0 = P[0]
    p1 = P[1]
    p2 = P[2]
    p3 = P[3]
    r1 = p0 - p1
    r2 = p1 - p2
    r3 = p3 - p2
    cp_12 = np.cross(r1, r2)
    cp_32 = np.cross(r3, r2)
    return np.dstack((cp_12, np.zeros(cp_12.shape), cp_32)) \
        .squeeze() \
        .transpose([1, 0])
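# Usage sketch for torsionVecs_: for four 3D points it returns a (3, 3)
# array whose rows are cp_12, a zero vector, and cp_32. `self` is unused,
# so None stands in for the instance.
import numpy as np  # the snippet's `np` may equally be jax.numpy

P = np.array([[0.0, 0.0, 0.0],
              [1.0, 0.0, 0.0],
              [1.0, 1.0, 0.0],
              [2.0, 1.0, 0.0]])
vecs = torsionVecs_(None, P)  # vecs.shape == (3, 3)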
def contour_grid(xmin, xmax, ymin, ymax, n_x, n_y, n_importance_samples=None):
    x_range, y_range = jnp.linspace(xmin, xmax, n_x), jnp.linspace(ymin, ymax, n_y)
    X, Y = jnp.meshgrid(x_range, y_range)
    XY = jnp.dstack([X, Y]).reshape((-1, 2))
    if n_importance_samples is not None:
        XY = jnp.broadcast_to(XY[None, ...], (n_importance_samples,) + XY.shape)

    def reshape_to_grid(Z):
        return Z.reshape(X.shape)

    return X, Y, XY, reshape_to_grid
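# Usage sketch for contour_grid: evaluate a function on the flattened grid,
# then reshape back for plotting (e.g. plt.contourf(X, Y, reshape_to_grid(Z))).
import jax.numpy as jnp

X, Y, XY, reshape_to_grid = contour_grid(-3.0, 3.0, -3.0, 3.0, n_x=100, n_y=100)
Z = jnp.exp(-0.5 * jnp.sum(XY**2, axis=-1))  # unnormalised Gaussian density
Z_grid = reshape_to_grid(Z)                  # (100, 100), matches X and Y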
def torsionVecs(self, P):
    # P[..., [k], [0, 1, 2]] broadcasts [k] against [0, 1, 2], picking out
    # point k's xyz coordinates with shape (..., 3).
    p0 = P[..., [0], [0, 1, 2]]
    p1 = P[..., [1], [0, 1, 2]]
    p2 = P[..., [2], [0, 1, 2]]
    p3 = P[..., [3], [0, 1, 2]]
    r1 = p0 - p1
    r2 = p1 - p2
    r3 = p3 - p2
    cp_12 = np.cross(r1, r2)
    cp_32 = np.cross(r3, r2)
    return np.dstack((cp_12, np.zeros(cp_12.shape), cp_32)) \
        .squeeze() \
        .transpose([0, 2, 1])
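# Usage sketch for the batched torsionVecs: a stack of N point quadruples of
# shape (N, 4, 3) yields an (N, 3, 3) result, one row per cross-product
# vector. `self` is unused, so None stands in for the instance.
import numpy as np  # the snippet's `np` may equally be jax.numpy

P = np.random.rand(2, 4, 3)  # 2 sets of 4 points in 3D
vecs = torsionVecs(None, P)  # vecs.shape == (2, 3, 3)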
def generate_image(
    height,
    width,
    scene_camera,
    world,
    config,
):
    """Generates an image of dimensions (height x width x 3) from the given camera."""

    def process_pixel(
        position,
        num_samples,
        rng,
    ):
        j, i = position

        def get_color_at_sample(u, v, sample_rng):
            ray = scene_camera.get_ray(u, v)
            return compute_color_fn(ray, rng=sample_rng).array()

        # Derive a per-pixel RNG from the pixel's flat index.
        pixel_rng = jax.random.fold_in(rng, width * i + j)
        pixel_rng, i_rng, j_rng = jax.random.split(pixel_rng, num=3)

        # Random samples for anti-aliasing.
        random_is = jax.random.uniform(i_rng, shape=(num_samples,))
        random_js = jax.random.uniform(j_rng, shape=(num_samples,))
        us = (j + random_js) / width
        vs = (i + random_is) / height

        sample_rngs = jax.random.split(pixel_rng, num=num_samples)
        colors = jax.vmap(get_color_at_sample)(us, vs, sample_rngs)
        colors = jnp.mean(colors, axis=0)
        return colors

    num_samples = config.num_antialiasing_samples
    rng = jax.random.PRNGKey(config.rng_seed)
    compute_color_fn = functools.partial(compute_color, world=world, config=config)
    process_pixel_fn = functools.partial(process_pixel,
                                         num_samples=num_samples,
                                         rng=rng)
    # vmap over rows and columns of the pixel grid.
    process_pixel_fn = jax.vmap(jax.vmap(process_pixel_fn))

    grid = jnp.dstack(jnp.meshgrid(jnp.arange(width), jnp.arange(height)))
    image = process_pixel_fn(grid)
    return image
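# A minimal illustration (not from the source) of the pixel-grid trick used
# in generate_image: dstack over meshgrid yields a (height, width, 2) array
# whose [i, j] entry is (j, i), which the doubly-vmapped process_pixel
# consumes one pixel position at a time.
import jax.numpy as jnp

height, width = 2, 3
grid = jnp.dstack(jnp.meshgrid(jnp.arange(width), jnp.arange(height)))
print(grid.shape)   # (2, 3, 2)
print(grid[1, 2])   # [2 1], i.e. (j=2, i=1)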
def _dstack_product(x, y):
    """Returns the cartesian product of the elements of x and y vectors.

    Args:
      x: 1d array
      y: 1d array of the same dtype as x.

    Returns:
      a 2D array of shape (len(x) * len(y), 2) containing the elements of [x]x[y].

    Example:
      x = jnp.array([1, 2, 3])
      y = jnp.array([4, 5])
      _dstack_product(x, y)
      >>> [[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]]
    """
    return jnp.dstack(jnp.meshgrid(x, y, indexing="ij")).reshape(-1, 2)
def generate_grid(key, n_samples, min_val, max_val, n_clusters_per_axis):
    x = jnp.linspace(min_val, max_val, n_clusters_per_axis)
    y = jnp.linspace(min_val, max_val, n_clusters_per_axis)
    X, Y = jnp.meshgrid(x, y)
    xy = jnp.dstack([X, Y]).reshape((-1, 2))

    # Repeat the data so that we can add noise to different copies.
    n_repeats = n_samples // (n_clusters_per_axis**2)
    data = jnp.repeat(xy, repeats=n_repeats, axis=0)

    # Add just enough noise so that we see each cluster without overlapping.
    std = (max_val - min_val) / n_clusters_per_axis * 0.25
    noise_key, perm_key = random.split(key)  # avoid reusing one key twice
    noise = random.normal(noise_key, data.shape) * std
    data += noise

    data = random.permutation(perm_key, data)
    return data
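# Usage sketch for generate_grid: 900 noisy points arranged around a 3x3
# grid of cluster centres in [-1, 1]^2.
from jax import random

key = random.PRNGKey(0)
data = generate_grid(key, n_samples=900, min_val=-1.0, max_val=1.0,
                     n_clusters_per_axis=3)
print(data.shape)  # (900, 2)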
def predict(self, X, y=None, p=None):
    """
    Parameters
    ==========
    X : array_like, shape (n_samples, n_features)
        Stimulus design matrix.
    y : None or array_like, shape (n_samples, )
        Recorded response. Needed when the post-spike filter is fitted.
    p : None or dict
        Model parameters. Only needed if model performance is monitored
        during training.
    """
    if self.n_c > 1:
        XS = jnp.dstack([X[:, :, i] @ self.S for i in range(self.n_c)
                         ]).reshape(X.shape[0], -1)
    else:
        XS = X @ self.S

    extra = {'X': X, 'XS': XS, 'y': y}

    if self.h_spl is not None:
        if y is None:
            raise ValueError('`y` is needed for calculating response history.')
        yh = jnp.array(
            build_design_matrix(extra['y'][:, jnp.newaxis],
                                self.Sh.shape[0],
                                shift=self.shift_h))
        yS = yh @ self.Sh
        extra.update({'yS': yS})

    params = self.p_opt if p is None else p
    y_pred = self.forwardpass(params, extra=extra)
    return y_pred
def make_gradient_field(function, xrange=(-1, 2), yrange=(-1, 2), n_points=30,
                        shape=(2, 1)):
    W = jnp.linspace(*xrange, n_points)
    B = jnp.linspace(*yrange, n_points)
    U, V = jnp.meshgrid(W, B)
    pairs = jnp.dstack([U, V]).reshape(-1, *shape)

    vectorized_fun = jit(vmap(function))
    Z = vectorized_fun(pairs).reshape(n_points, n_points)

    grad_fun = jit(vmap(grad(function)))
    gradvals = grad_fun(pairs)
    gradx = gradvals[:, 0].reshape(n_points, n_points)
    grady = gradvals[:, 1].reshape(n_points, n_points)
    gradnorm = jnp.sqrt(gradx**2 + grady**2)

    return U, V, Z, pairs, gradvals, gradx, grady, gradnorm
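# Usage sketch for make_gradient_field on a simple quadratic bowl; the
# default shape=(2, 1) matches functions that expect column-vector inputs.
import jax.numpy as jnp

def bowl(w):
    return jnp.sum(w ** 2)

U, V, Z, pairs, gradvals, gradx, grady, gradnorm = make_gradient_field(bowl)
# e.g. plt.quiver(U, V, gradx, grady): arrows point away from the minimum.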
def _sample_next(sampler, machine, parameters: PyTree,
                 state: MetropolisPtSamplerState):
    new_rng, rng = jax.random.split(state.rng)

    with loops.Scope() as s:
        s.key = rng
        s.σ = state.σ
        s.log_prob = sampler.machine_pow * machine(parameters, state.σ).real
        # For logging.
        s.beta = state.beta
        s.beta_0_index = state.beta_0_index
        s.n_accepted_per_beta = state.n_accepted_per_beta
        s.beta_position = state.beta_position
        s.beta_diffusion = state.beta_diffusion

        for i in s.range(sampler.n_sweeps):
            # 1 key to propagate for the next iteration, 1 for the uniform rng
            # and n_chains for the transition kernel.
            s.key, key1, key2, key3, key4 = jax.random.split(s.key, 5)
            beta = s.beta

            σp, log_prob_correction = sampler.rule.transition(
                sampler, machine, parameters, state, key1, s.σ)
            proposal_log_prob = sampler.machine_pow * machine(parameters, σp).real

            uniform = jax.random.uniform(key2, shape=(sampler.n_batches,))
            if log_prob_correction is not None:
                do_accept = uniform < jnp.exp(
                    beta.reshape((-1,)) *
                    (proposal_log_prob - s.log_prob + log_prob_correction))
            else:
                do_accept = uniform < jnp.exp(
                    beta.reshape((-1,)) * (proposal_log_prob - s.log_prob))

            # do_accept must match the ndim of proposal and state (which is 2).
            s.σ = jnp.where(do_accept.reshape(-1, 1), σp, s.σ)
            n_accepted_per_beta = s.n_accepted_per_beta + do_accept.reshape(
                (sampler.n_chains, sampler.n_replicas))
            s.log_prob = jax.numpy.where(do_accept.reshape(-1),
                                         proposal_log_prob, s.log_prob)

            # Exchange betas: randomly decide whether every set of replicas
            # should be swapped in even or odd order.
            swap_order = jax.random.randint(
                key3, minval=0, maxval=2, shape=(sampler.n_chains,))  # 0 or 1

            # Indices of even-swapped elements (per row) ...
            idxs = jnp.arange(0, sampler.n_replicas, 2).reshape(
                (1, -1)) + swap_order.reshape((-1, 1))
            # ... and of odd-swapped elements (per row).
            inn = (idxs + 1) % sampler.n_replicas

            # For every row of the input, swap elements at idxs with elements
            # at inn (legacy jax.ops API, equivalent to .at[idxs].set(...)).
            @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
            def swap_rows(beta_row, idxs, inn):
                proposed_beta = jax.ops.index_update(
                    beta_row, idxs, beta_row[inn],
                    unique_indices=True, indices_are_sorted=True)
                proposed_beta = jax.ops.index_update(
                    proposed_beta, inn, beta_row[idxs],
                    unique_indices=True, indices_are_sorted=False)
                return proposed_beta

            proposed_beta = swap_rows(beta, idxs, inn)

            @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
            def compute_proposed_prob(prob, idxs, inn):
                return prob[idxs] + prob[inn]

            # Compute the acceptance probability of the swaps.
            log_prob = (proposed_beta - state.beta) * s.log_prob.reshape(
                (sampler.n_chains, sampler.n_replicas))
            prob_rescaled = jnp.exp(compute_proposed_prob(log_prob, idxs, inn))

            uniform = jax.random.uniform(
                key4, shape=(sampler.n_chains, sampler.n_replicas // 2))
            do_swap = uniform < prob_rescaled
            # Duplicate each decision so it covers both members of a pair,
            # concatenating along the last dimension.
            do_swap = jnp.dstack((do_swap, do_swap)).reshape(
                (-1, sampler.n_replicas))

            # Roll if swap_order is odd.
            @partial(jax.vmap, in_axes=(0, 0), out_axes=0)
            def fix_swap(do_swap, swap_order):
                return jax.lax.cond(swap_order == 0, lambda x: x,
                                    lambda x: jnp.roll(x, 1), do_swap)

            do_swap = fix_swap(do_swap, swap_order)

            # The replica temperatures move with the accepted swaps.
            new_beta = jax.numpy.where(do_swap, proposed_beta, beta)
            s.beta = new_beta

            swap_order = swap_order.reshape(-1)
            beta_0_moved = jax.vmap(lambda do_swap, i: do_swap[i],
                                    in_axes=(0, 0),
                                    out_axes=0)(do_swap, state.beta_0_index)
            proposed_beta_0_index = jnp.mod(
                state.beta_0_index + (-jnp.mod(swap_order, 2) * 2 + 1) *
                (-jnp.mod(state.beta_0_index, 2) * 2 + 1),
                sampler.n_replicas)
            s.beta_0_index = jnp.where(beta_0_moved, proposed_beta_0_index,
                                       s.beta_0_index)

            # Swap the acceptance counters along with the betas.
            swapped_n_accepted_per_beta = swap_rows(n_accepted_per_beta, idxs, inn)
            s.n_accepted_per_beta = jax.numpy.where(
                do_swap, swapped_n_accepted_per_beta, n_accepted_per_beta)

            # Update statistics used to compute the diffusion coefficient of
            # the replicas (total exchange steps performed).
            delta = s.beta_0_index - s.beta_position
            s.beta_position = s.beta_position + delta / (state.exchange_steps + i)
            delta2 = s.beta_0_index - s.beta_position
            s.beta_diffusion = s.beta_diffusion + delta * delta2

        new_state = state.replace(
            rng=new_rng,
            σ=s.σ,
            n_samples=state.n_samples + sampler.n_sweeps * sampler.n_chains,
            beta=s.beta,
            beta_0_index=s.beta_0_index,
            beta_position=s.beta_position,
            beta_diffusion=s.beta_diffusion,
            exchange_steps=state.exchange_steps + sampler.n_sweeps,
            n_accepted_per_beta=s.n_accepted_per_beta,
        )

    offsets = jnp.arange(0, sampler.n_chains * sampler.n_replicas,
                         sampler.n_replicas)
    return new_state, new_state.σ[new_state.beta_0_index + offsets, :]
def loop_body(i, s):
    # 1 key to propagate for the next iteration, 1 for the uniform rng and
    # n_chains for the transition kernel.
    s["key"], key1, key2, key3, key4 = jax.random.split(s["key"], 5)
    beta = s["beta"]

    σp, log_prob_correction = sampler.rule.transition(
        sampler, machine, parameters, state, key1, s["σ"])
    proposal_log_prob = sampler.machine_pow * machine.apply(parameters, σp).real

    uniform = jax.random.uniform(key2, shape=(sampler.n_batches,))
    if log_prob_correction is not None:
        do_accept = uniform < jnp.exp(
            beta.reshape((-1,)) *
            (proposal_log_prob - s["log_prob"] + log_prob_correction))
    else:
        do_accept = uniform < jnp.exp(
            beta.reshape((-1,)) * (proposal_log_prob - s["log_prob"]))

    # do_accept must match the ndim of proposal and state (which is 2).
    s["σ"] = jnp.where(do_accept.reshape(-1, 1), σp, s["σ"])
    n_accepted_per_beta = s["n_accepted_per_beta"] + do_accept.reshape(
        (sampler.n_chains, sampler.n_replicas))
    s["log_prob"] = jax.numpy.where(do_accept.reshape(-1), proposal_log_prob,
                                    s["log_prob"])

    # Exchange betas: randomly decide whether every set of replicas should be
    # swapped in even or odd order.
    swap_order = jax.random.randint(
        key3, minval=0, maxval=2, shape=(sampler.n_chains,))  # 0 or 1

    # Indices of even-swapped elements (per row) ...
    idxs = jnp.arange(0, sampler.n_replicas, 2).reshape(
        (1, -1)) + swap_order.reshape((-1, 1))
    # ... and of odd-swapped elements (per row).
    inn = (idxs + 1) % sampler.n_replicas

    # For every row of the input, swap elements at idxs with elements at inn.
    @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
    def swap_rows(beta_row, idxs, inn):
        proposed_beta = beta_row.at[idxs].set(beta_row[inn],
                                              unique_indices=True,
                                              indices_are_sorted=True)
        proposed_beta = proposed_beta.at[inn].set(beta_row[idxs],
                                                  unique_indices=True,
                                                  indices_are_sorted=False)
        return proposed_beta

    proposed_beta = swap_rows(beta, idxs, inn)

    @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
    def compute_proposed_prob(prob, idxs, inn):
        return prob[idxs] + prob[inn]

    # Compute the acceptance probability of the swaps.
    log_prob = (proposed_beta - state.beta) * s["log_prob"].reshape(
        (sampler.n_chains, sampler.n_replicas))
    prob_rescaled = jnp.exp(compute_proposed_prob(log_prob, idxs, inn))

    uniform = jax.random.uniform(
        key4, shape=(sampler.n_chains, sampler.n_replicas // 2))
    do_swap = uniform < prob_rescaled
    # Duplicate each decision so it covers both members of a pair,
    # concatenating along the last dimension.
    do_swap = jnp.dstack((do_swap, do_swap)).reshape((-1, sampler.n_replicas))

    # Roll if swap_order is odd.
    @partial(jax.vmap, in_axes=(0, 0), out_axes=0)
    def fix_swap(do_swap, swap_order):
        return jax.lax.cond(swap_order == 0, lambda x: x,
                            lambda x: jnp.roll(x, 1), do_swap)

    do_swap = fix_swap(do_swap, swap_order)

    # The replica temperatures move with the accepted swaps.
    s["beta"] = jax.numpy.where(do_swap, proposed_beta, beta)

    swap_order = swap_order.reshape(-1)
    beta_0_moved = jax.vmap(lambda do_swap, i: do_swap[i],
                            in_axes=(0, 0), out_axes=0)(do_swap,
                                                        state.beta_0_index)
    proposed_beta_0_index = jnp.mod(
        state.beta_0_index + (-jnp.mod(swap_order, 2) * 2 + 1) *
        (-jnp.mod(state.beta_0_index, 2) * 2 + 1),
        sampler.n_replicas)
    s["beta_0_index"] = jnp.where(beta_0_moved, proposed_beta_0_index,
                                  s["beta_0_index"])

    # Swap the acceptance counters along with the betas.
    swapped_n_accepted_per_beta = swap_rows(n_accepted_per_beta, idxs, inn)
    s["n_accepted_per_beta"] = jax.numpy.where(do_swap,
                                               swapped_n_accepted_per_beta,
                                               n_accepted_per_beta)

    # Update statistics used to compute the diffusion coefficient of the
    # replicas (total exchange steps performed).
    delta = s["beta_0_index"] - s["beta_position"]
    s["beta_position"] = s["beta_position"] + delta / (
        state.exchange_steps + jnp.asarray(i, dtype=jnp.int64))
    delta2 = s["beta_0_index"] - s["beta_position"]
    s["beta_diffusion"] = s["beta_diffusion"] + delta * delta2
    return s
def dstack(arrays):
    arrays = [a.value if isinstance(a, JaxArray) else a for a in arrays]
    return JaxArray(jnp.dstack(arrays))
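# Usage sketch for the JaxArray-aware dstack wrapper above, assuming JaxArray
# is a thin container exposing the raw array as `.value` (as the isinstance
# unwrap implies); plain arrays are passed through untouched.
import jax.numpy as jnp

a = JaxArray(jnp.ones((2, 2)))
b = jnp.zeros((2, 2))
stacked = dstack([a, b])  # JaxArray wrapping an array of shape (2, 2, 2)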
def evaluate_2d_model(create_model, args, classification=False):
    assert not args.save_path.endswith(".pickle")

    init_key = random.PRNGKey(args.init_key_seed)
    train_key = random.PRNGKey(args.train_key_seed)
    eval_key = random.PRNGKey(args.eval_key_seed)

    train_ds, get_test_ds = get_dataset(args.dataset,
                                        args.batch_size,
                                        args.n_batches,
                                        args.test_batch_size,
                                        args.test_n_batches,
                                        quantize_bits=args.quantize_bits,
                                        classification=classification,
                                        label_keep_percent=1.0,
                                        random_label_percent=0.0)

    doubly_batched_inputs = next(train_ds)
    inputs = {"x": doubly_batched_inputs["x"][0]}
    if "y" in doubly_batched_inputs:
        inputs["y"] = doubly_batched_inputs["y"][0]

    flow = nux.Flow(create_model, init_key, inputs, batch_axes=(0,))
    outputs = flow.apply(init_key, inputs)
    print("n_params", flow.n_params)

    trainer = initialize_trainer(flow,
                                 clip=args.clip,
                                 lr=args.lr,
                                 warmup=args.warmup,
                                 cosine_decay_steps=args.cosine_decay_steps,
                                 save_path=args.save_path,
                                 retrain=args.retrain,
                                 train_args=args.train_args,
                                 classification=classification)

    test_losses = sorted(trainer.test_losses.items(), key=lambda x: x[0])
    test_losses = jnp.array(test_losses)

    test_ds = get_test_ds()
    res = trainer.evaluate_test(eval_key, test_ds)
    print("test", trainer.summarize_losses_and_aux(res))

    # Plot samples.
    samples = flow.sample(eval_key, n_samples=5000, manifold_sample=True)

    # Find the spread of the data.
    data = doubly_batched_inputs["x"].reshape((-1, 2))
    (xmin, ymin), (xmax, ymax) = data.min(axis=0), data.max(axis=0)
    xspread, yspread = xmax - xmin, ymax - ymin
    xmin -= 0.25 * xspread
    xmax += 0.25 * xspread
    ymin -= 0.25 * yspread
    ymax += 0.25 * yspread

    # Plot the samples against the true samples, plus a density plot.
    if "prediction" in samples:
        fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(28, 7))
        ax1.scatter(*data.T)
        ax1.set_title("True Samples")
        ax2.scatter(*samples["x"].T, alpha=0.2, s=3, c=samples["prediction"])
        ax2.set_title("Learned Samples")
        ax1.set_xlim(xmin, xmax)
        ax1.set_ylim(ymin, ymax)
        ax2.set_xlim(xmin, xmax)
        ax2.set_ylim(ymin, ymax)

        n_importance_samples = 100
        x_range = jnp.linspace(xmin, xmax, 100)
        y_range = jnp.linspace(ymin, ymax, 100)
        X, Y = jnp.meshgrid(x_range, y_range)
        XY = jnp.dstack([X, Y]).reshape((-1, 2))
        XY = jnp.broadcast_to(XY[None, ...],
                              (n_importance_samples,) + XY.shape)

        outputs = flow.scan_apply(eval_key, {"x": XY})
        # Average the importance samples in log space.
        outputs["log_px"] = jax.scipy.special.logsumexp(
            outputs["log_px"], axis=0) - jnp.log(n_importance_samples)
        outputs["prediction"] = jnp.mean(outputs["prediction"], axis=0)
        Z = jnp.exp(outputs["log_px"])

        ax3.contourf(X, Y, Z.reshape(X.shape))
        ax4.contourf(X, Y, outputs["prediction"].reshape(X.shape))
        plt.show()
    else:
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(21, 7))
        ax1.scatter(*data.T)
        ax1.set_title("True Samples")
        ax2.scatter(*samples["x"].T, alpha=0.2, s=3)
        ax2.set_title("Learned Samples")
        ax1.set_xlim(xmin, xmax)
        ax1.set_ylim(ymin, ymax)
        ax2.set_xlim(xmin, xmax)
        ax2.set_ylim(ymin, ymax)

        n_importance_samples = 100
        x_range = jnp.linspace(xmin, xmax, 100)
        y_range = jnp.linspace(ymin, ymax, 100)
        X, Y = jnp.meshgrid(x_range, y_range)
        XY = jnp.dstack([X, Y]).reshape((-1, 2))
        XY = jnp.broadcast_to(XY[None, ...],
                              (n_importance_samples,) + XY.shape)

        outputs = flow.scan_apply(eval_key, {"x": XY})
        outputs["log_px"] = jax.scipy.special.logsumexp(
            outputs["log_px"], axis=0) - jnp.log(n_importance_samples)
        Z = jnp.exp(outputs["log_px"])

        ax3.contourf(X, Y, Z.reshape(X.shape))
        plt.show()
def fit(self, p0=None, extra=None, num_subunits=2, num_epochs=1,
        num_iters=3000, initialize='random', metric=None, alpha=1, beta=0.05,
        fit_linear_filter=True, fit_intercept=True, fit_R=True,
        fit_history_filter=False, fit_nonlinearity=False, step_size=1e-2,
        tolerance=10, verbose=100, random_seed=2046, return_model=None):
    self.metric = metric

    self.alpha = alpha  # elastic net parameter (1=L1, 0=L2)
    self.beta = beta    # elastic net parameter - global penalty weight

    self.n_s = num_subunits
    self.num_iters = num_iters

    self.fit_linear_filter = fit_linear_filter
    self.fit_history_filter = fit_history_filter
    self.fit_nonlinearity = fit_nonlinearity
    self.fit_intercept = fit_intercept
    self.fit_R = fit_R

    # Initialize parameters.
    if p0 is None:
        p0 = {}

    dict_keys = p0.keys()
    if 'b' not in dict_keys:
        if initialize == 'random':
            # Not necessary, but for consistency with the other models.
            key = random.PRNGKey(random_seed)
            b0 = 0.01 * random.normal(
                key, shape=(self.n_b * self.n_c * self.n_s,)).flatten()
            p0.update({'b': b0})

    if 'intercept' not in dict_keys:
        p0.update({'intercept': jnp.zeros(1)})

    if 'R' not in dict_keys:
        p0.update({'R': jnp.array([1.])})

    if 'bh' not in dict_keys:
        try:
            p0.update({'bh': self.bh_spl})
        except AttributeError:  # no response-history spline was built
            p0.update({'bh': None})

    if 'nl_params' not in dict_keys:
        if self.nl_params is not None:
            p0.update(
                {'nl_params': [self.nl_params for _ in range(self.n_s + 1)]})
        else:
            p0.update({'nl_params': [None for _ in range(self.n_s + 1)]})

    if extra is not None:
        if self.n_c > 1:
            XS_ext = jnp.dstack([
                extra['X'][:, :, i] @ self.S for i in range(self.n_c)
            ]).reshape(extra['X'].shape[0], -1)
            extra.update({'XS': XS_ext})
        else:
            extra.update({'XS': extra['X'] @ self.S})

        if self.h_spl is not None:
            yh = jnp.array(
                build_design_matrix(extra['y'][:, jnp.newaxis],
                                    self.Sh.shape[0], shift=1))
            yS = yh @ self.Sh
            extra.update({'yS': yS})

        extra = {key: jnp.array(extra[key]) for key in extra.keys()}

    self.p0 = p0
    self.p_opt = self.optimize_params(p0, extra, num_epochs, num_iters,
                                      metric, step_size, tolerance, verbose,
                                      return_model)

    self.R = self.p_opt['R'] if fit_R else jnp.array([1.])

    if fit_linear_filter:
        self.b_opt = self.p_opt['b']
        if self.n_c > 1:
            self.w_opt = jnp.stack([
                self.S @ self.b_opt.reshape(self.n_b, self.n_c,
                                            self.n_s)[:, :, i]
                for i in range(self.n_s)
            ], axis=-1)
        else:
            self.w_opt = self.S @ self.b_opt.reshape(self.n_b, self.n_s)

    if fit_history_filter:
        self.bh_opt = self.p_opt['bh']
        self.h_opt = self.Sh @ self.bh_opt

    if fit_intercept:
        self.intercept = self.p_opt['intercept']

    if fit_nonlinearity:
        self.nl_params_opt = self.p_opt['nl_params']
def __init__(self, X, y, dims, df, smooth='cr', compute_mle=False, **kwargs):
    """
    Parameters
    ==========
    X : array_like, shape (n_samples, n_features)
        Stimulus design matrix.
    y : array_like, shape (n_samples, )
        Recorded response.
    dims : list or array_like, shape (ndims, )
        Dimensions or shape of the RF to estimate. Assumed order [t, sx, sy].
    df : list or array_like, shape (ndims, )
        Degrees of freedom, i.e. the number of basis functions used for each
        RF dimension.
    smooth : str
        Type of basis.
        * cr: natural cubic spline (default)
        * cc: cyclic cubic spline
        * bs: B-spline
        * tp: thin plate spline
    compute_mle : bool
        Optionally compute the STA and the maximum likelihood estimate.
    """
    super().__init__(X, y, dims, compute_mle, **kwargs)

    # Optimization
    self.bh_opt = None
    self.b_opt = None
    self.extra = None
    self.h_spl = None
    self.bh_spl = None
    self.yS = None
    self.Sh = None

    # Parameters
    self.df = df          # number of basis functions / degrees of freedom
    self.smooth = smooth  # type of basis

    S = jnp.array(build_spline_matrix(self.dims, df, smooth))  # for w

    if self.n_c > 1:
        XS = jnp.dstack([self.X[:, :, i] @ S for i in range(self.n_c)
                         ]).reshape(self.n_samples, -1)
    else:
        XS = self.X @ S

    self.S = S  # spline matrix
    self.XS = XS
    self.n_b = S.shape[1]  # number of spline coefficients

    # Compute the spline-based maximum likelihood estimate.
    self.b_spl = jnp.linalg.lstsq(XS.T @ XS, XS.T @ y, rcond=None)[0]

    if self.n_c > 1:
        self.w_spl = S @ self.b_spl.reshape(self.n_b, self.n_c)
    else:
        self.w_spl = S @ self.b_spl
def fit(self, p0=None, extra=None, initialize='random', num_epochs=1,
        num_iters=3000, metric=None, alpha=1, beta=0.05,
        fit_linear_filter=True, fit_intercept=True, fit_R=True,
        fit_history_filter=False, fit_nonlinearity=False, step_size=1e-2,
        tolerance=10, verbose=100, random_seed=2046, return_model=None):
    """
    Parameters
    ==========
    p0 : dict
        * 'b': Initial spline coefficients.
        * 'bh': Initial response history filter coefficients.
    initialize : None or str
        Parametric initialization.
        * if `initialize=None`, `b` will be initialized with `b_spl`.
        * if `initialize='random'`, `b` will be randomly initialized.
    num_iters : int
        Maximum number of optimization iterations.
    metric : None or str
        Extra cross-validation metric. Default is `None`. Or
        * 'mse': mean squared error
        * 'r2': R2 score
        * 'corrcoef': correlation coefficient
    alpha : float, from 0 to 1
        Elastic net parameter: the balance between L1 and L2 regularization.
        * 0.0 -> only L2
        * 1.0 -> only L1
    beta : float
        Elastic net parameter: the overall weight of the regularization on
        the receptive field.
    step_size : float
        Initial step size for the JAX optimizer.
    tolerance : int
        Early-stopping tolerance. Optimization stops when the cost has been
        increasing, or has stopped improving, for `tolerance` steps.
    verbose : int
        When `verbose=0`, progress is not printed. When `verbose=n`,
        progress is printed every n steps.
    """
    self.metric = metric

    self.alpha = alpha
    self.beta = beta  # elastic net parameter - global penalty weight for the linear filter

    self.num_iters = num_iters

    self.fit_R = fit_R
    self.fit_linear_filter = fit_linear_filter
    self.fit_history_filter = fit_history_filter
    self.fit_nonlinearity = fit_nonlinearity
    self.fit_intercept = fit_intercept

    # Initial parameters.
    if p0 is None:
        p0 = {}

    dict_keys = p0.keys()
    if 'b' not in dict_keys:
        if initialize is None:
            p0.update({'b': self.b_spl})
        elif initialize == 'random':
            key = random.PRNGKey(random_seed)
            b0 = 0.01 * random.normal(
                key, shape=(self.n_b * self.n_c,)).flatten()
            p0.update({'b': b0})

    if 'intercept' not in dict_keys:
        p0.update({'intercept': jnp.array([0.])})

    if 'R' not in dict_keys:
        p0.update({'R': jnp.array([1.])})

    if 'bh' not in dict_keys:
        if initialize is None and self.bh_spl is not None:
            p0.update({'bh': self.bh_spl})
        elif initialize == 'random' and self.bh_spl is not None:
            key = random.PRNGKey(random_seed)
            bh0 = 0.01 * random.normal(key,
                                       shape=(len(self.bh_spl),)).flatten()
            p0.update({'bh': bh0})
        else:
            p0.update({'bh': None})

    if 'nl_params' not in dict_keys:
        if self.nl_params is not None:
            p0.update({'nl_params': self.nl_params})
        else:
            p0.update({'nl_params': None})

    if extra is not None:
        if self.n_c > 1:
            XS_ext = jnp.dstack([
                extra['X'][:, :, i] @ self.S for i in range(self.n_c)
            ]).reshape(extra['X'].shape[0], -1)
            extra.update({'XS': XS_ext})
        else:
            extra.update({'XS': extra['X'] @ self.S})

        if self.h_spl is not None:
            yh_ext = jnp.array(
                build_design_matrix(extra['y'][:, jnp.newaxis],
                                    self.Sh.shape[0], shift=1))
            yS_ext = yh_ext @ self.Sh
            extra.update({'yS': yS_ext})

        extra = {key: jnp.array(extra[key]) for key in extra.keys()}
        self.extra = extra  # store for cross-validation

    # Store optimized parameters.
    self.p0 = p0
    self.p_opt = self.optimize_params(p0, extra, num_epochs, num_iters,
                                      metric, step_size, tolerance, verbose,
                                      return_model)

    self.R = self.p_opt['R'] if fit_R else jnp.array([1.])

    if fit_linear_filter:
        self.b_opt = self.p_opt['b']  # optimized RF basis coefficients
        if self.n_c > 1:
            self.w_opt = self.S @ self.b_opt.reshape(self.n_b, self.n_c)
        else:
            self.w_opt = self.S @ self.b_opt  # optimized RF

    if fit_history_filter:
        self.bh_opt = self.p_opt['bh']
        self.h_opt = self.Sh @ self.bh_opt

    if fit_nonlinearity:
        self.nl_params_opt = self.p_opt['nl_params']

    if fit_intercept:
        self.intercept = self.p_opt['intercept']
def _plot(fig, ax1, ax2, mean, sigma, array_samples_theta, interactive=False):
    colorbar = None

    fig.clear()
    fig.add_axes(ax1)
    fig.add_axes(ax2)
    plt.cla()

    xlim = (-5., 5.)
    ylim = (-5., 5.)
    xlist = np.linspace(*xlim, 100)
    ylist = np.linspace(*ylim, 100)
    X_, Y_ = np.meshgrid(xlist, ylist)
    Z = np.dstack((X_, Y_))
    Z = Z.reshape(-1, 2)

    predictions = onp.mean(probability_class_1(Z, array_samples_theta), axis=1)
    predictions = predictions.reshape(100, 100)

    ax1.clear()
    if np.size(predictions):
        CS = ax1.contourf(X_, Y_, predictions, cmap="cividis")
    ax1.scatter(X_1[:, 0], X_1[:, 1])
    ax1.scatter(X_2[:, 0], X_2[:, 1])
    ax1.set_xlim(*xlim)
    ax1.set_ylim(*ylim)
    ax1.set_title("Predicted probability of belonging to C_1")

    ax3 = fig.add_axes(Bbox([[0.43, 0.11], [0.453, 0.88]]))
    if np.size(predictions):
        colorbar = fig.colorbar(CS, cax=ax3)
    ax1.set_position(Bbox([[0.125, 0.11], [0.39, 0.88]]))

    # Contour of the standard-normal prior.
    x_prior = np.linspace(-3, 3, 100)
    y_prior = np.linspace(-3, 3, 100)
    X_prior, Y_prior = np.meshgrid(x_prior, y_prior)
    Z = np.dstack((X_prior, Y_prior))
    Z = Z.reshape(-1, 2)
    prior_values = multivariate_normal.pdf(Z, np.zeros(2), np.identity(2))
    prior_values = prior_values.reshape(100, 100)

    # Contour of the Gaussian posterior approximation, spanning +/-3 stddev.
    std_x = onp.sqrt(sigma[0, 0])
    std_y = onp.sqrt(sigma[1, 1])
    x_posterior = np.linspace(mean[0] - 3 * std_x, mean[0] + 3 * std_x, 100)
    y_posterior = np.linspace(mean[1] - 3 * std_y, mean[1] + 3 * std_y, 100)
    X_post, Y_post = np.meshgrid(x_posterior, y_posterior)
    Z_post = np.dstack((X_post, Y_post)).reshape(-1, 2)
    posterior_values = multivariate_normal.pdf(Z_post, mean, sigma)
    posterior_values = posterior_values.reshape(100, 100)

    ax2.contour(X_post, Y_post, posterior_values)
    ax2.contour(X_, Y_, prior_values, cmap="inferno")
    ax2.set_title("Two contour plots showing the prior and\n"
                  "the approximated posterior distributions")

    plt.pause(0.001)
    if interactive:
        if np.size(predictions):
            colorbar.remove()
    return True