def nearest_sampler(imgs, coords, mask_value):
    """Construct a new image by nearest sampling from the input image.

    Points falling outside the source image boundary have value of mask_value.

    Args:
      imgs: source image to be sampled from [height_s, width_s, channels].
        (The code below indexes height/width/channels directly, so the input
        is unbatched, despite what some callers may assume.)
      coords: coordinates of source pixels to sample from [height_t, width_t, 2].
        height_t/width_t correspond to the dimensions of the output image
        (don't need to be the same as height_s/width_s). The two channels
        correspond to x and y coordinates respectively.
      mask_value: value of points outside of image. -1 for edge sampling.

    Returns:
      A new sampled image [height_t, width_t, channels].
    """
    coords_x, coords_y = jnp.split(coords, 2, axis=2)
    inp_size = imgs.shape
    out_size = list(coords.shape)
    out_size[2] = imgs.shape[2]

    coords_x = jnp.array(coords_x, dtype='float32')
    coords_y = jnp.array(coords_y, dtype='float32')

    y_max = jnp.array(jnp.shape(imgs)[0] - 1, dtype='float32')
    x_max = jnp.array(jnp.shape(imgs)[1] - 1, dtype='float32')
    zero = jnp.zeros([1], dtype='float32')
    eps = jnp.array([0.5], dtype='float32')

    coords_x_clipped = jnp.clip(coords_x, zero - eps, x_max + eps)
    coords_y_clipped = jnp.clip(coords_y, zero - eps, y_max + eps)

    x0 = jnp.round(coords_x_clipped)
    y0 = jnp.round(coords_y_clipped)

    x0_safe = jnp.clip(x0, zero, x_max)
    y0_safe = jnp.clip(y0, zero, y_max)

    # Indices in the flat image to sample from.
    dim2 = jnp.array(inp_size[1], dtype='float32')
    base_y0 = y0_safe * dim2
    idx00 = jnp.reshape(x0_safe + base_y0, [-1])

    # Sample from imgs.
    imgs_flat = jnp.reshape(imgs, [-1, inp_size[2]])
    imgs_flat = imgs_flat.astype('float32')
    output = jnp.reshape(
        jnp.take(imgs_flat, idx00.astype('int32'), axis=0), out_size)

    return jnp.where(
        jnp.any(mask_value > 0),
        jnp.where(
            compute_mask(coords_x, coords_y, x_max, y_max), output,
            jnp.ones_like(output) * jnp.reshape(jnp.array(mask_value), [1, 1, -1])),
        output)
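# Example usage of nearest_sampler: a minimal sketch. `compute_mask` is
# referenced above but not shown here, so we supply a hypothetical in-bounds
# indicator matching the call site (this is an assumption, not the original
# helper).
import jax.numpy as jnp


def compute_mask(coords_x, coords_y, x_max, y_max):
    # Hypothetical helper: True where (x, y) lies inside [0, x_max] x [0, y_max].
    in_x = jnp.logical_and(coords_x >= 0., coords_x <= x_max)
    in_y = jnp.logical_and(coords_y >= 0., coords_y <= y_max)
    return jnp.logical_and(in_x, in_y)


img = jnp.arange(16., dtype=jnp.float32).reshape(4, 4, 1)   # [h, w, c]
ys, xs = jnp.meshgrid(jnp.arange(4.), jnp.arange(4.), indexing="ij")
coords = jnp.stack([xs, ys], axis=-1)                       # [h, w, 2], (x, y)
out = nearest_sampler(img, coords, mask_value=-1)           # identity resample
# `out` equals `img`, since every coordinate lands exactly on a source pixel.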
def loss_fn(param_dict, signal):
    params = param_dict["nn"]
    sigma = param_dict["s"]
    hf = 1
    N = int(jnp.round(6 * sigma))
    # Add some noise during training to prevent the classifier from
    # overfitting on irrelevant aspects of the spectra.
    signal = signal + 0.2 * np.random.randn(signal.shape[0])
    x = diff_stft(signal, s=sigma, hf=hf)

    # Label each frame 1 or 2 by whichever onset list it is closest to in time.
    # (`fs`, `I1`, `I2`, `Nzp` and `net` are module-level globals here.)
    li = []
    l1 = jnp.array([[1, 0]])
    l2 = jnp.array([[0, 1]])
    l_c = []
    for i in range(x.shape[1]):
        timi = i * int(hf * N) / fs
        d1 = np.min(np.abs(I1 - timi))
        d2 = np.min(np.abs(I2 - timi))
        if d1 < d2:
            li.append(1)
            l_c.append(l1)
        else:
            li.append(2)
            l_c.append(l2)
    li = np.array(li)
    l_c = np.concatenate(l_c, axis=0).T

    # Zero-pad the spectrogram to a fixed number of frequency bins so the
    # network input size does not depend on sigma.
    xzp = jnp.concatenate(
        [x, jnp.zeros((Nzp - (N // 2 + 1), x.shape[1]))], axis=0)
    logits = net.apply(params, xzp.T)
    # Regularized loss (cross entropy + a regularizer to avoid small windows).
    cel = -jnp.mean(logits * l_c.T) + (0.1 / sigma)
    return cel
def testLossAndGradientsAreFinite(self):
    # Test that the loss and its approximation both give finite losses and
    # derivatives everywhere that they should for a wide range of values.
    num_samples = 100000
    rng = random.PRNGKey(0)

    # Normally distributed inputs.
    rng, key = random.split(rng)
    x = random.normal(key, shape=[num_samples])

    # Uniformly distributed values in (-16, 3), quantized to the nearest 0.1
    # to ensure that we hit the special cases at 0, 2.
    rng, key = random.split(rng)
    alpha = jnp.round(
        random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
        10) / 10.

    # Random log-normally distributed values in approx (1e-5, 100000):
    rng, key = random.split(rng)
    scale = jnp.exp(random.normal(key, shape=[num_samples]) * 4.) + 1e-5

    fn = self.variant(general.lossfun)
    loss = fn(x, alpha, scale)
    d_x, d_alpha, d_scale = (
        jax.grad(lambda x, a, s: jnp.sum(fn(x, a, s)), [0, 1, 2])(
            x, alpha, scale))

    for v in [loss, d_x, d_alpha, d_scale]:
        chex.assert_tree_all_finite(v)
def create_stepped_learning_rate_fn(base_learning_rate, steps_per_epoch,
                                    lr_sched_steps, warmup_length=0.0):
    """Create a stepped learning rate function.

    Args:
      base_learning_rate: base learning rate
      steps_per_epoch: number of steps per epoch
      lr_sched_steps: learning rate schedule as a list of pairs where each
        pair is `[step, lr_factor]`
      warmup_length: linear LR warmup length; 0 for no warmup

    Returns:
      function of the form f(step) -> learning_rate
    """
    boundaries = [step[0] for step in lr_sched_steps]
    decays = [step[1] for step in lr_sched_steps]
    boundaries = jnp.array(boundaries) * steps_per_epoch
    boundaries = jnp.round(boundaries).astype(jnp.int32)
    values = jnp.array([1.0] + decays) * base_learning_rate

    def step_fn(step):
        lr = piecewise_constant(boundaries, values, step)
        if warmup_length > 0.0:
            lr = lr * jnp.minimum(
                1., step / float(warmup_length) / steps_per_epoch)
        return lr

    return step_fn
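# Example: a minimal sketch of using the schedule above. `piecewise_constant`
# is assumed by the function but not shown here, so we define a hypothetical
# stand-in with the usual semantics (values[i] applies until boundaries[i]).
import jax.numpy as jnp


def piecewise_constant(boundaries, values, step):
    # Hypothetical stand-in: count how many boundaries `step` has passed.
    index = jnp.sum(step >= boundaries)
    return jnp.take(values, index)


lr_fn = create_stepped_learning_rate_fn(
    base_learning_rate=0.1,
    steps_per_epoch=100,
    lr_sched_steps=[[2, 0.1], [4, 0.01]],  # decay 10x at epoch 2, 100x at epoch 4
    warmup_length=1.0)                     # one epoch of linear warmup
print(lr_fn(50))   # ~0.05: halfway through warmup
print(lr_fn(300))  # 0.01:  past the first decay boundary (step 200)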
def input_wavefront(self, wavelength=1e-6):
    """Create a Wavefront object suitable for sending through a given optical
    system. Uses self.source_offset to assign an off-axis tilt, if requested.

    (FIXME does not work for Fresnel yet)

    Parameters
    ----------
    wavelength : float
        Wavelength in meters

    Returns
    -------
    wavefront : morphine.fresnel.FresnelWavefront instance
        A wavefront appropriate for passing through this optical system.
    """
    oversample = int(np.round(1 / self.beam_ratio))
    inwave = FresnelWavefront(self.pupil_diameter / 2,
                              wavelength=wavelength,
                              npix=self.npix,
                              oversample=oversample)
    # _log.debug(
    #     "Creating input wavefront with wavelength={0} microns,"
    #     "npix={1}, diam={3}, pixel scale={2}".format(
    #         wavelength * 1e6, self.npix,
    #         self.pupil_diameter / (self.npix), self.pupil_diameter))
    inwave._display_hint_expected_nplanes = len(
        self)  # For displaying a multi-step calculation nicely
    return inwave
def topk_mask_internal(value):
    assert value.ndim == 1
    indices = jnp.argsort(value)
    k = jnp.round(density_fraction * jnp.size(value)).astype(jnp.int32)
    # Ones in the last k slots of the sorted order, then scattered back so
    # that the k largest entries of `value` receive a 1.
    mask = jnp.greater_equal(jnp.arange(value.size), value.size - k)
    mask = jnp.zeros_like(mask).at[indices].set(mask)
    return mask.astype(jnp.int32)
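# Example usage: `density_fraction` is a module-level global in the original,
# so we set it here before calling.
import jax.numpy as jnp

density_fraction = 0.5
v = jnp.array([0.3, -1.2, 2.5, 0.0])
print(topk_mask_internal(v))  # [1 0 1 0]: keeps the k = 2 largest entries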
def Epot(pos, *args):
    M, L = args
    pos = pos.reshape((3, M))
    energy = 0
    for i in range(M - 1):
        for j in range(i + 1, M):
            # Minimum-image convention: wrap each separation into [-L/2, L/2]
            # so each pair interacts through its nearest periodic image.
            deltaX = pos[0, i] - pos[0, j]
            deltaXmi = deltaX - L * np.round(deltaX / L)
            deltaY = pos[1, i] - pos[1, j]
            deltaYmi = deltaY - L * np.round(deltaY / L)
            deltaZ = pos[2, i] - pos[2, j]
            deltaZmi = deltaZ - L * np.round(deltaZ / L)
            r = np.linalg.norm([deltaXmi, deltaYmi, deltaZmi])
            energy += Vlj(r)
    return energy
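# Worked example of the minimum-image wrap used above: with box length L = 10,
# particles at x = 0.5 and x = 9.5 are only 1.0 apart through the periodic
# boundary, not 9.0.
import numpy as np

delta = 0.5 - 9.5                       # -9.0 raw separation
delta_mi = delta - 10 * np.round(delta / 10)
print(delta_mi)                         # 1.0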
def get_attn():
    return stax.GlobalSelfAttention(
        n_chan_out=width,
        n_chan_key=width,
        n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
        n_heads=int(np.sqrt(width)),
    ) if proj == 'avg' else stax.Identity()
def test_logistic_regression(self):
    key = random.PRNGKey(0)
    N, n = 5, 2
    key, k1, k2, k3 = random.split(key, num=4)
    X_np = random.normal(k1, shape=(N, n))
    a_true = random.normal(k2, shape=(n, 1))
    y_np = jnp.round(
        sigmoid(X_np @ a_true + random.normal(k3, shape=(N, 1)) * 0.5))

    X_jax = jnp.array(X_np)
    lam_jax = 0.1 * jnp.ones(1)

    a = cp.Variable((n, 1))
    X = cp.Parameter((N, n))
    lam = cp.Parameter(1, nonneg=True)
    y = y_np
    log_likelihood = cp.sum(
        cp.multiply(y, X @ a) -
        cp.log_sum_exp(cp.hstack([np.zeros((N, 1)), X @ a]).T,
                       axis=0, keepdims=True).T)
    prob = cp.Problem(cp.Minimize(-log_likelihood + lam * cp.sum_squares(a)))

    fit_logreg = CvxpyLayer(prob, [X, lam], [a])
    check_grads(fit_logreg, (X_jax, lam_jax), order=1, modes=['rev'])
def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST):
    """The Hermitian matrix `H` is split into two matrices `Hm` and `Hp`,
    respectively sharing its eigenspaces beneath and above its `split_point`th
    eigenvalue.

    Returns, in addition, `Vm` and `Vp`, isometries such that
    `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are returned
    instead; this allows the overall isometries mapping from an initial input
    matrix to progressively smaller blocks to be formed.

    Args:
      H: The Hermitian matrix to split.
      split_point: The eigenvalue to split along.
      V0: Matrix of isometries to be updated.
      precision: TPU matmul precision.

    Returns:
      Hm: A Hermitian matrix sharing the eigenvalues of `H` beneath
        `split_point`.
      Vm: An isometry from the input space of `V0` to `Hm`.
      Hp: A Hermitian matrix sharing the eigenvalues of `H` above
        `split_point`.
      Vp: An isometry from the input space of `V0` to `Hp`.
    """
    def _fill_diagonal(X, vals):
        # `jax.ops.index_update` has been removed from recent JAX releases;
        # `.at[...].set(...)` below is its modern equivalent.
        return X.at[jnp.diag_indices(X.shape[0])].set(vals)

    H_shift = _fill_diagonal(H, H.diagonal() - split_point)
    U, _ = jsp.linalg.polar_unitary(H_shift)
    P = -0.5 * _fill_diagonal(U, U.diagonal() - 1.)
    rank = jnp.round(jnp.trace(P)).astype(jnp.int32)
    rank = int(rank)
    return _split_spectrum_jittable(P, H, V0, rank, precision)
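# A note on why `rank = round(trace(P))` works above (our reading of the code,
# not a claim from the original source): for a Hermitian shifted matrix, the
# unitary polar factor U acts as a matrix sign function with eigenvalues +/-1,
# so P = (I - U) / 2 is (numerically) an orthogonal projector onto the
# eigenspace below `split_point`. The trace of a projector equals its rank,
# up to floating-point noise that the round removes.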
def grassman_distance(y1, y2):
    """Grassmann distance between subspaces spanned by y1 and y2."""
    q1, _ = jnp.linalg.qr(y1)
    q2, _ = jnp.linalg.qr(y2)

    _, sigma, _ = jnp.linalg.svd(q1.T @ q2)
    # Rounding keeps the singular values inside [-1, 1] so arccos stays real.
    sigma = jnp.round(sigma, decimals=6)
    return jnp.linalg.norm(jnp.arccos(sigma))
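# Quick check: for one-dimensional subspaces the Grassmann distance reduces to
# the angle between the spanning vectors, so orthogonal lines give pi/2.
import jax.numpy as jnp

e1 = jnp.array([[1.], [0.]])
e2 = jnp.array([[0.], [1.]])
print(grassman_distance(e1, e2))   # ~1.5708 (pi/2)
print(grassman_distance(e1, e1))   # ~0.0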
def compute_grassman_distance(Y1, Y2):
    """Grassmann distance between subspaces spanned by Y1 and Y2."""
    Q1, _ = jnp.linalg.qr(Y1)
    Q2, _ = jnp.linalg.qr(Y2)
    _, sigma, _ = jnp.linalg.svd(Q1.T @ Q2)
    sigma = jnp.round(sigma, decimals=6)
    return jnp.linalg.norm(jnp.arccos(sigma))
def testRoundStaticDecimals(self, shape, dtype, decimals, rng):
    if onp.issubdtype(dtype, onp.integer) and decimals < 0:
        self.skipTest("Integer rounding with decimals < 0 not implemented")
    onp_fun = lambda x: onp.round(x, decimals=decimals)
    lnp_fun = lambda x: lnp.round(x, decimals=decimals)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
def testRoundStaticDecimals(self, shape, dtype, decimals, rng):
    onp_fun = lambda x: onp.round(x, decimals=decimals)
    lnp_fun = lambda x: lnp.round(x, decimals=decimals)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
def grassman_distance(Y1, Y2):  # pylint: disable=invalid-name
    """Grassmann distance between subspaces spanned by Y1 and Y2."""
    Q1, _ = jnp.linalg.qr(Y1)  # pylint: disable=invalid-name
    Q2, _ = jnp.linalg.qr(Y2)  # pylint: disable=invalid-name

    _, sigma, _ = jnp.linalg.svd(Q1.T @ Q2)
    # sigma = jnp.clip(sigma, -1., 1.)
    sigma = jnp.round(sigma, decimals=6)
    return jnp.linalg.norm(jnp.arccos(sigma))
def jax_invech(v):
    '''Inverse half-vectorization operator: rebuild the symmetric matrix whose
    upper triangle (including the diagonal) was flattened into `v`.'''
    rows = int(jnp.round(.5 * (-1 + jnp.sqrt(1 + 8 * len(v)))))
    res = jnp.zeros((rows, rows))
    # `jax.ops.index_update` is removed in recent JAX; `.at[...].set(...)` is
    # its modern equivalent.
    res = res.at[jnp.triu_indices(rows)].set(v)
    res = res + res.T - jnp.diag(jnp.diag(res))
    return res
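# Example: the half-vectorization of a 3x3 symmetric matrix has
# 3 * (3 + 1) / 2 = 6 entries; `jax_invech` recovers the full matrix.
import jax.numpy as jnp

v = jnp.array([1., 2., 3., 4., 5., 6.])
print(jax_invech(v))
# [[1. 2. 3.]
#  [2. 4. 5.]
#  [3. 5. 6.]]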
def save_images(batch, fname):
    n_rows = batch.shape[0] // 16
    batch = onp.uint8(jnp.round((batch + 1) * 127.5))  # [-1, 1] -> [0, 255]
    out = onp.full((1 + 33 * n_rows, 1 + 33 * 16, 3), 255, 'uint8')
    for i, im in enumerate(batch):
        top = 1 + 33 * (i // 16)
        left = 1 + 33 * (i % 16)
        out[top:top + 32, left:left + 32] = im
    Image.fromarray(out).save(fname)
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout):
    fc = partial(stax.Dense, W_std=W_std, b_std=b_std,
                 parameterization=parameterization)
    conv = partial(stax.Conv,
                   filter_shape=filter_shape,
                   strides=strides,
                   padding=padding,
                   W_std=W_std,
                   b_std=b_std,
                   parameterization=parameterization)
    affine = conv(width) if is_conv else fc(width)

    rate = np.onp.random.uniform(0.5, 0.9)
    dropout = stax.Dropout(rate, mode='train')
    ave_pool = stax.AvgPool((2, 3), None,
                            'SAME' if padding == 'SAME' else 'CIRCULAR')
    ave_pool_or_identity = ave_pool if use_pooling else stax.Identity()
    dropout_or_identity = dropout if use_dropout else stax.Identity()
    layer_norm_or_identity = (stax.Identity() if layer_norm is None
                              else stax.LayerNorm(axis=layer_norm))
    res_unit = stax.serial(ave_pool_or_identity, phi, dropout_or_identity,
                           affine)
    if is_res:
        block = stax.serial(affine, stax.FanOut(2),
                            stax.parallel(stax.Identity(), res_unit),
                            stax.FanInSum(), layer_norm_or_identity)
    else:
        block = stax.serial(affine, res_unit, layer_norm_or_identity)

    if proj_into_2d == 'FLAT':
        proj_layer = stax.Flatten()
    elif proj_into_2d == 'POOL':
        proj_layer = stax.GlobalAvgPool()
    elif proj_into_2d.startswith('ATTN'):
        n_heads = int(np.sqrt(width))
        n_chan_val = int(np.round(float(width) / n_heads))
        fixed = proj_into_2d == 'ATTN_FIXED'
        proj_layer = stax.serial(
            stax.GlobalSelfAttention(width,
                                     n_chan_key=width,
                                     n_chan_val=n_chan_val,
                                     n_heads=n_heads,
                                     fixed=fixed,
                                     W_key_std=W_std,
                                     W_value_std=W_std,
                                     W_query_std=W_std,
                                     W_out_std=1.0,
                                     b_std=b_std),
            stax.Flatten())
    else:
        raise ValueError(proj_into_2d)
    readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

    return stax.serial(block, readout)
def predict(params, state, action_field=None, action_size=2, action_layer=[2, 3]):
    """Predict the next state given the current state.

    Args:
        params: network parameters
        state: current state

    Returns:
        predicted_state: next state prediction
        action: action to take
    """
    # `data` is a mutable variable that holds the intermediary states
    # between layers.
    data = state
    i = 0
    action_data = []  # was `np.array([])`, which has no .append method
    for w, b in params[:-1]:
        data = np.add(np.dot(w, data), b)
        i += 1
        if action_field and i in action_layer:
            action_data.append(data)

    try:
        sin_cut = wandb.config.sin_cut
    except KeyError:
        wandb.config.sin_cut = 0.001

    if action_field:
        assert len(action_field) == 2
        for i in range(0, len(action_field or [])):
            # All action_field arrays must match the collected activations.
            assert len(action_field[i]) == len(action_data)
        # action_field has two sets of parameters.
        # Fast GPU noise:
        # http://people.compute.dtu.dk/jerf/papers/abstracts/noise_abstract.pdf
        # Note: the original had `1 / np.dot(len(action_field))`, which is a
        # TypeError (np.dot takes two arguments); `1 / len(action_field)` is
        # the most plausible intent.
        action = np.round(
            (1 / len(action_field)) *
            np.sin(np.dot(action_data, action_field[1])))
    else:
        action = None

    final_w, final_b = params[-1]
    predicted_state = np.tanh(np.dot(final_w, data))
    # TODO: Make this come out of a noise function
    action = 0.0
    return predicted_state, action
def f(params, pot_ini):
    x = params[:-1]
    v = params[-1]
    vprint = jnp.round(v, 2)
    print("evaluating for V = {:.2f}".format(vprint))
    eff, pot = simulator.eff_at_bias(convr(x), v, pot_ini, verbose=False)
    return -eff, pot
def letter_seq(arr: np.ndarray) -> str:
    """
    Convert a 2D one-hot array into a string representation.

    TODO: More docstrings needed.
    """
    sequence = ""
    for letter in arr:
        sequence += arr_to_letter(np.round(letter))
    # Note: the original used .strip("start").strip("stop"), but str.strip
    # removes any of the given *characters* from both ends, not the
    # substrings; removeprefix/removesuffix (Python 3.9+) match the intent.
    return sequence.removeprefix("start").removesuffix("stop")
def sample(self, key, params):
    sample_x = distribution_utils.sample_from_discretized_mix_logistic_rgb(
        key, params, self.n_mixtures)            # range [-1., 1.]
    sample_x = (sample_x + 1.) / 2.              # range [0., 1.]
    sample_x = sample_x * (self.n_classes - 1.)  # range [0., n_classes - 1.]
    # Better to round now, otherwise we get truncation (floor, for these
    # non-negative values) when the result is later cast to int32.
    sample_x = jnp.round(sample_x)
    return sample_x
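# Worked example of the rescaling above (values chosen for illustration): a
# raw sample of 0.5 in [-1., 1.] with 256 classes maps to
# (0.5 + 1) / 2 * 255 = 191.25, and rounding before the int32 cast gives 191
# rather than relying on truncation.
import jax.numpy as jnp

raw = jnp.array(0.5)
pixel = jnp.round((raw + 1.) / 2. * 255.)
print(pixel, pixel.astype(jnp.int32))  # 191.0 191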
def sample_time_jump_with_linear_increase(step, num_train_steps, min_jump,
                                          max_jump, rng):
    """Returns a stochastic jump size, with linearly increasing mean."""
    max_time_jump_for_step = min_jump + (step / (num_train_steps - 1)) * (
        max_jump - min_jump)
    max_time_jump_for_step = jnp.round(max_time_jump_for_step)
    jump = jax.random.randint(rng, (), min_jump, max_time_jump_for_step + 1)
    jump = int(jump)
    return jump
def uoro_grad(self, key, theta, state, s_tilde=None, theta_tilde=None):
    epsilon_perturbation = 1e-7
    epsilon_stability = 1e-7

    total_theta_grad = 0
    total_loss = 0.0

    if s_tilde is None:
        s_tilde = jnp.zeros(state.inner_state.shape)
    if theta_tilde is None:
        theta_tilde = jnp.zeros(theta.shape)

    state_old = state
    # TODO: How do we handle key here? Do we want to split again?
    loss, state_new = self.unroll_fn(key, theta, state_old, self.T, 1)
    total_loss += loss

    dl_dstate_old = self.compute_dL_dstate_old(theta, state_old)
    dl_dtheta_direct = self.compute_dL_dtheta_direct(theta, state_old)
    dl_dstate_old = dl_dstate_old.inner_state

    indirect_grad = (dl_dstate_old * s_tilde).sum() * theta_tilde
    pseudograds = indirect_grad + dl_dtheta_direct

    state_old_perturbed = state_old._replace(
        inner_state=state_old.inner_state + s_tilde * epsilon_perturbation)
    state_new_perturbed = self.f(theta, state_old_perturbed)
    state_deriv_in_direction_s_tilde = (
        (state_new_perturbed - state_new.inner_state) / epsilon_perturbation)

    nus = jnp.round(jax.random.uniform(
        key, state_old.inner_state.shape)) * 2 - 1

    custom_f = lambda param_vector: self.f(param_vector, state_old)
    primals, f_vjp = jax.vjp(custom_f, theta)
    direct_theta_tilde_contribution, = f_vjp(nus)

    rho_0 = jnp.sqrt(
        (jnp.linalg.norm(theta_tilde) + epsilon_stability) /
        (jnp.linalg.norm(state_deriv_in_direction_s_tilde) + epsilon_stability))
    rho_1 = jnp.sqrt(
        (jnp.linalg.norm(direct_theta_tilde_contribution) + epsilon_stability) /
        (jnp.linalg.norm(nus) + epsilon_stability))

    theta_grad = pseudograds
    total_theta_grad += theta_grad

    s_tilde = rho_0 * state_deriv_in_direction_s_tilde + rho_1 * nus
    theta_tilde = (theta_tilde / rho_0 +
                   direct_theta_tilde_contribution / rho_1)

    return (total_loss, state_new, s_tilde, theta_tilde), total_theta_grad
def apply_bond_charge_corrections(initial_charges, bond_idxs, deltas):
    """For an arbitrary collection of ordered bonds and associated increments
    `(a, b, delta)`, update `charges` by `charges[a] += delta`,
    `charges[b] -= delta`.

    Notes
    -----
    * preserves sum(initial_charges) for arbitrary values of bond_idxs or
      deltas
    * order within each row of bond_idxs is meaningful:
      `(..., bond_idxs, deltas)` means the opposite of
      `(..., bond_idxs[:, ::-1], deltas)`
    * order within the first axis of bond_idxs, deltas is not meaningful:
      `(..., bond_idxs[perm], deltas[perm])` means the same thing for any
      permutation `perm`
    """
    # apply bond charge corrections
    # (`jax.ops.index_add` is removed in recent JAX; `.at[...].add(...)` is
    # its modern equivalent)
    incremented = initial_charges.at[bond_idxs[:, 0]].add(+deltas)
    decremented = incremented.at[bond_idxs[:, 1]].add(-deltas)
    final_charges = decremented

    # make some safety assertions
    assert bond_idxs.shape[1] == 2
    assert len(deltas) == len(bond_idxs)

    net_charge = jnp.sum(initial_charges)
    net_charge_is_integral = jnp.isclose(net_charge, jnp.round(net_charge),
                                         atol=1e-5)

    final_net_charge = jnp.sum(final_charges)
    net_charge_is_unchanged = jnp.isclose(final_net_charge, net_charge,
                                          atol=1e-5)

    assert net_charge_is_integral
    assert net_charge_is_unchanged

    # print some safety warnings
    directed_bonds = Counter([tuple(b) for b in bond_idxs])
    undirected_bonds = Counter([tuple(sorted(b)) for b in bond_idxs])

    if max(directed_bonds.values()) > 1:
        duplicates = [bond for (bond, count) in directed_bonds.items()
                      if count > 1]
        print(UserWarning(f"Duplicate directed bonds! {duplicates}"))
    elif max(undirected_bonds.values()) > 1:
        duplicates = [bond for (bond, count) in undirected_bonds.items()
                      if count > 1]
        print(UserWarning(f"Duplicate undirected bonds! {duplicates}"))

    return final_charges
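# Example: transfer 0.1e along the directed bond (0, 1); atom 0 gains, atom 1
# loses, and the total charge stays at zero. Bond indices are kept as plain
# numpy so the Counter-based duplicate check above can hash the rows.
import numpy as np
import jax.numpy as jnp

charges = jnp.zeros(3)
bonds = np.array([[0, 1]])
deltas = jnp.array([0.1])
print(apply_bond_charge_corrections(charges, bonds, deltas))  # [ 0.1 -0.1  0. ]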
def sample_mask_indices(input_dim, hidden_dim):
    """
    Samples the indices assigned to hidden units during the construction of
    MADE masks.

    :param input_dim: the dimensionality of the input variable
    :type input_dim: int
    :param hidden_dim: the dimensionality of the hidden layer
    :type hidden_dim: int
    """
    indices = jnp.linspace(1, input_dim, num=hidden_dim)
    # Simple procedure tries to space fractional indices evenly by rounding
    # to the nearest int.
    return jnp.round(indices)
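# Worked example: spacing 5 hidden units across input indices 1..3 gives
# linspace(1, 3, 5) = [1., 1.5, 2., 2.5, 3.], and jnp.round (half-to-even)
# maps that to [1., 2., 2., 2., 3.].
import jax.numpy as jnp

print(sample_mask_indices(input_dim=3, hidden_dim=5))  # [1. 2. 2. 2. 3.]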
def diff_stft(xinp, s, hf=0.5):
    """
    Inputs
    ------
    xinp: jnp.array
        Input audio signal in time domain
    s: jnp.float
        The standard deviation of the Gaussian window to be used
    hf: jnp.float
        The fraction of window size that will be overlapped within
        consecutive frames

    Outputs
    -------
    a: jnp.array
        The computed magnitude spectrogram
    """
    # Effective window length of a Gaussian is 6*sigma.
    sz = s * 6
    hp = hf * sz
    # Truncating to integers for use in jnp functions.
    intsz = int(jnp.round(sz))
    inthp = int(jnp.round(hp))

    m = jnp.arange(0, intsz, dtype=jnp.float32)
    # Obtain the "differentiable" window function by using the real-valued
    # sigma.
    window = jnp.exp(-0.5 * jnp.power((m - sz / 2) / (s + 1e-5), 2))
    window_norm = window / jnp.sum(window)

    # Compute the STFT, and take its magnitude.
    stft = jnp.sqrt(1 / (2 * window_norm.shape[0] + 1)) * jnp.stack([
        jnp.fft.rfft(window_norm * xinp[i:i + intsz])
        for i in range(0, len(xinp) - intsz, inthp)
    ], 1)
    a = jnp.abs(stft)
    return a
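# Example: spectrogram of one second of a 440 Hz tone at 8 kHz sampling, with
# sigma = 32 samples (window of 6 * 32 = 192 samples, 50% overlap by default).
import jax.numpy as jnp

fs = 8000.0
t = jnp.arange(8000) / fs
tone = jnp.sin(2 * jnp.pi * 440.0 * t)
spec = diff_stft(tone, s=32.0)
print(spec.shape)  # (97, 82): 192 // 2 + 1 frequency bins, 82 frames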
def _precompute_lossfun_inputs(self):
    """Precompute a loss and its derivatives for random inputs and parameters.

    Generates a large number of random inputs to the loss, and random
    shape/scale parameters for the loss function at each sample, and computes
    the loss and its derivative with respect to all inputs and parameters,
    returning everything to be used to assert various properties in our unit
    tests.

    Returns:
      A tuple containing:
       (the number (int) of samples, and the length of all following arrays,
        A tensor of losses for each sample,
        A tensor of residuals of each sample (the loss inputs),
        A tensor of shape parameters of each loss,
        A tensor of scale parameters of each loss,
        A tensor of derivatives of each loss wrt each x,
        A tensor of derivatives of each loss wrt each alpha,
        A tensor of derivatives of each loss wrt each scale)

    Typical usage example:
      (num_samples, loss, x, alpha, scale, d_x, d_alpha, d_scale) =
          self._precompute_lossfun_inputs()
    """
    num_samples = 100000
    rng = random.PRNGKey(0)

    # Normally distributed inputs.
    rng, key = random.split(rng)
    x = random.normal(key, shape=[num_samples])

    # Uniformly distributed values in (-16, 3), quantized to the nearest 0.1
    # to ensure that we hit the special cases at 0, 2.
    rng, key = random.split(rng)
    alpha = jnp.round(
        random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
        10) / 10.
    # Push the sampled alphas at the extents of the range to +/- infinity, so
    # that we probe those cases too.
    alpha = jnp.where(alpha == 3, jnp.inf, alpha)
    alpha = jnp.where(alpha == -16, -jnp.inf, alpha)

    # Random log-normally distributed values in approx (1e-5, 100000):
    rng, key = random.split(rng)
    scale = jnp.exp(random.normal(key, shape=[num_samples]) * 4.) + 1e-5

    fn = self.variant(general.lossfun)
    loss = fn(x, alpha, scale)
    d_x, d_alpha, d_scale = (
        jax.grad(lambda x, a, s: jnp.sum(fn(x, a, s)), [0, 1, 2])(
            x, alpha, scale))
    return (num_samples, loss, x, alpha, scale, d_x, d_alpha, d_scale)
def body_func(args):
    xi, accumulated_sum = args
    xi_float = jnp.asarray(xi, dtype=dtype)
    log_xi_factorial = lax.lgamma(xi_float + 1.)
    log_comb_n_xi = (log_n_factorial - log_xi_factorial -
                     lax.lgamma(total_count - xi_float + 1.))
    comb_n_xi = jnp.round(jnp.exp(log_comb_n_xi))
    likelihood1 = math.power_no_nan(probs, xi)
    likelihood2 = math.power_no_nan(1. - probs, total_count - xi)
    likelihood = likelihood1 * likelihood2
    comb_term = comb_n_xi * log_xi_factorial * likelihood  # [K]
    chex.assert_shape(comb_term, (probs.shape[-1],))
    return xi + 1, accumulated_sum + comb_term
def uoro_grad(key, theta, state, s_tilde=None, theta_tilde=None):
    epsilon_perturbation = 1e-7
    epsilon_stability = 1e-7

    total_theta_grad = 0
    total_loss = 0.0

    if s_tilde is None:
        s_tilde = jnp.zeros(state.shape)
    if theta_tilde is None:
        theta_tilde = jnp.zeros(theta.shape)

    state_old = state                    # (23,)
    state_new = f(theta, state_old)      # (23,)
    loss = L(theta, state_old)
    total_loss += loss

    dl_dstate_old = compute_dL_dstate_old(theta, state_old)        # (23,)
    dl_dtheta_direct = compute_dL_dtheta_direct(theta, state_old)  # (1,)

    indirect_grad = (dl_dstate_old * s_tilde).sum() * theta_tilde  # (1,)
    pseudograds = indirect_grad + dl_dtheta_direct                 # (1,)

    state_old_perturbed = state_old + s_tilde * epsilon_perturbation  # (23,)
    state_new_perturbed = f(theta, state_old_perturbed)               # (23,)
    state_deriv_in_direction_s_tilde = (
        state_new_perturbed - state_new) / epsilon_perturbation       # (23,)

    nus = jnp.round(jax.random.uniform(key, state_old.shape)) * 2 - 1  # (23,)

    # The tricky part is this first line.
    custom_f = lambda param_vector: f(param_vector, state_old)
    primals, f_vjp = jax.vjp(custom_f, theta)
    direct_theta_tilde_contribution, = f_vjp(nus)  # (1,)

    rho_0 = jnp.sqrt(
        (jnp.linalg.norm(theta_tilde) + epsilon_stability) /
        (jnp.linalg.norm(state_deriv_in_direction_s_tilde) + epsilon_stability))
    rho_1 = jnp.sqrt(
        (jnp.linalg.norm(direct_theta_tilde_contribution) + epsilon_stability) /
        (jnp.linalg.norm(nus) + epsilon_stability))

    theta_grad = pseudograds
    total_theta_grad += theta_grad

    s_tilde = rho_0 * state_deriv_in_direction_s_tilde + rho_1 * nus
    theta_tilde = (theta_tilde / rho_0 +
                   direct_theta_tilde_contribution / rho_1)

    return (total_loss, state_new, s_tilde, theta_tilde), total_theta_grad