def svd_fun(tensor):
    # Note: init_str, left_free, right_free, bond_char, left_str, right_str,
    # stable_svd, apply_sv and out_fun are captured from the enclosing scope.

    # Permute the indices of tensor into something closer to the output
    tensor = jnp.einsum(f"{init_str}->{left_free+right_free}", tensor)

    # Flatten both sides of our tensor to give a single matrix
    left_shape = tensor.shape[:len(left_free)]
    right_shape = tensor.shape[len(left_free):]
    left_size = jnp.prod(left_shape)
    right_size = jnp.prod(right_shape)
    matrix = tensor.reshape((left_size, right_size))

    # Get SVD and format so that left_mat@diag(svs)@right_mat = matrix
    left_mat, sv_vec, right_mat = stable_svd(matrix)

    # Fold singular values into left/right matrices
    left_mat, right_mat = apply_sv(left_mat, right_mat, sv_vec)

    # Reshape the matrices to make them proper tensors
    left_tensor = left_mat.reshape(left_shape + sv_vec.shape)
    right_tensor = right_mat.reshape(sv_vec.shape + right_shape)

    # Move the new bond indices into the correct order
    left_tensor = jnp.einsum(f"{left_free+bond_char}->{left_str}", left_tensor)
    right_tensor = jnp.einsum(f"{bond_char+right_free}->{right_str}", right_tensor)

    return out_fun(left_tensor, right_tensor, sv_vec)
def mnist(flatten: bool = False,
          one_hot_encoding: bool = False,
          data_dir: str = os.path.join("..", "datasets", "mnist")):
    path: Path = Path(data_dir)
    downloaded_data, info = tfds.load(name="mnist",
                                      batch_size=-1,
                                      data_dir=path,
                                      with_info=True)
    mnist_data: Dict[str, Dict[str, np.array]] = tfds.as_numpy(downloaded_data)
    train_data, valid_data = mnist_data.get("train"), mnist_data.get("test")
    input_shape: Tuple[int, ...] = info.features["image"].shape

    train_images = tensor.asarray(train_data.get("image"), dtype=tensor.float32)
    train_labels = tensor.asarray(train_data.get("label"),
                                  dtype=tensor.float32).reshape(-1, 1)
    valid_images = tensor.asarray(valid_data.get("image"), dtype=tensor.float32)
    valid_labels = tensor.asarray(valid_data.get("label"),
                                  dtype=tensor.float32).reshape(-1, 1)

    if flatten:
        train_images = train_images.reshape(-1, tensor.prod(list(input_shape)))
        valid_images = valid_images.reshape(-1, tensor.prod(list(input_shape)))

    if one_hot_encoding:
        # pd.get_dummies expects 1-D input, so flatten the (n, 1) label columns.
        train_labels = tensor.asarray(pd.get_dummies(train_labels.ravel()),
                                      dtype=tensor.float32)
        valid_labels = tensor.asarray(pd.get_dummies(valid_labels.ravel()),
                                      dtype=tensor.float32)

    return (train_images, train_labels), (valid_images, valid_labels)
def test_discrete_barycenter_grid(self, lse_mode, debiased, epsilon):
    """Tests the discrete barycenters on a 5x5x5 grid.

    Puts two masses on opposing ends of the hypercube with small noise in
    between. Checks that their W barycenter sits (mostly) at the middle of
    the hypercube (i.e. index (5x5x5-1)/2).

    Args:
      lse_mode: bool, lse or scaling computations.
      debiased: bool, use (or not) debiasing as proposed in
        https://arxiv.org/abs/2006.02575
      epsilon: float, regularization parameter
    """
    size = jnp.array([5, 5, 5])
    grid_3d = grid.Grid(grid_size=size, epsilon=epsilon)
    a = jnp.ones(size)
    b = jnp.ones(size)
    a = a.ravel()
    b = b.ravel()
    a = jax.ops.index_update(a, 0, 10000)
    b = jax.ops.index_update(b, -1, 10000)
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)
    threshold = 1e-2
    _, _, bar, errors = db.discrete_barycenter(grid_3d,
                                               a=jnp.stack((a, b)),
                                               threshold=threshold,
                                               lse_mode=lse_mode,
                                               debiased=debiased)
    self.assertGreater(bar[(jnp.prod(size) - 1) // 2], 0.7)
    self.assertGreater(1, bar[(jnp.prod(size) - 1) // 2])
    err = errors[jnp.isfinite(errors)][-1]
    self.assertGreater(threshold, err)
def test_nFS():
    from pOP import nFS as pnFS

    dim = 2
    nC = -1 * np.ones((dim, 1), dtype=np.int32)
    d = np.zeros(dim, dtype=np.int32)
    c = np.ones(dim)
    d2 = np.array([2, 3], dtype=np.int32)
    nC2Py = np.array([4, 7], dtype=np.int32)
    nC2 = np.block([[np.arange(4), -1. * np.ones(3)],
                    [np.arange(7)]]).astype(np.int32)

    n = np.array([10] * dim)
    N = np.prod(n)
    z = np.linspace(0, 2. * np.pi, num=n[0])

    x = onp.zeros((N, dim))
    for k in range(dim):
        nProd = np.prod(n[k + 1:])
        nStack = np.prod(n[0:k])
        dark = np.hstack([z] * nProd)
        x[:, k] = onp.array([dark] * nStack).flatten()

    c = (2. * np.pi) / (x[-1, :] - x[0, :])
    z = (x - x[0, :]) * c - np.pi

    nfs1 = nFS(x[0, :], x[-1, :], nC, 5)
    nfs2 = nFS(x[0, :], x[-1, :], nC2, 10)

    Fc1 = nfs1.H(x.T, d, False)
    Fc2 = nfs2.H(x.T, d2, False)

    Fp1 = pnFS(z, 4, d, nC.flatten() * 0.)
    Fp2 = pnFS(z, 9, d2, nC2Py)

    assert np.linalg.norm(Fc1 - Fp1, ord='fro') < 1e-14
    assert np.linalg.norm(Fc2 - Fp2, ord='fro') < 5e-13
def partial_trace(A, A_label):
    """ Partial trace on tensor A over repeated labels in A_label """
    num_cont = len(A_label) - len(np.unique(A_label))
    if num_cont > 0:
        dup_list = []
        for ele in np.unique(A_label):
            if sum(A_label == ele) > 1:
                dup_list.append([np.where(A_label == ele)[0]])

        cont_ind = np.array(dup_list).reshape(2 * num_cont, order='F')
        free_ind = onp.delete(np.arange(len(A_label)), cont_ind)

        cont_dim = np.prod(np.array(A.shape)[cont_ind[:num_cont]])
        free_dim = np.array(A.shape)[free_ind]

        B_label = onp.delete(A_label, cont_ind)
        cont_label = np.unique(A_label[cont_ind])
        B = np.zeros(np.prod(free_dim))
        A = A.transpose(np.append(free_ind, cont_ind)).reshape(
            np.prod(free_dim), cont_dim, cont_dim)
        for ip in range(cont_dim):
            B = B + A[:, ip, ip]

        return B.reshape(free_dim), B_label, cont_label
    else:
        return A, A_label, []
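# A minimal usage sketch for partial_trace (hypothetical tensor; assumes the
# snippet's apparent aliasing of `np` -> jax.numpy and `onp` -> plain NumPy).
# Tracing the repeated label 0 over axes 0 and 2 should match an einsum trace.
A_example = onp.random.rand(3, 4, 3, 5)
B_example, B_label, cont_label = partial_trace(A_example, onp.array([0, 1, 0, 2]))
assert onp.allclose(B_example, onp.einsum("iaib->ab", A_example))
assert list(B_label) == [1, 2]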
def eval_polynomial(
    x: jnp.ndarray,
    coeff_a: float,
    coeff_b: float,
    mul_coeffs: jnp.ndarray,
    sub_coeffs: jnp.ndarray,
) -> jnp.ndarray:
    """Evaluate polynomial.

    Evaluate the polynomial corresponding to the rational equation
    (coeff_b * x - coeff_a) + sum_i mul_coeffs[i]/(x-sub_coeffs[i]) at x.

    Args:
      x: (n,)
      coeff_a: Scalar
      coeff_b: Scalar
      mul_coeffs: (n,) numpy array of multiplicative coefficients
      sub_coeffs: (n,) numpy array of subtractive coefficients

    Returns:
      Values of polynomial at x (same shape as x).
    """
    result = 0.
    x = jnp.reshape(x, [-1, 1])
    for i in range(mul_coeffs.size):
        coeffs_not_i = (np.arange(mul_coeffs.size) != i)
        result += (mul_coeffs[i] * jnp.prod(
            x - jnp.reshape(sub_coeffs[coeffs_not_i], [1, -1]), axis=-1))
    result = jnp.reshape(result, [-1])
    result -= (jnp.reshape(coeff_b * x - coeff_a, [-1]) * jnp.reshape(
        jnp.prod(x - jnp.reshape(sub_coeffs, [1, -1]), axis=-1), [-1]))
    return jnp.reshape(result, [-1])
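# A minimal numeric sketch (hypothetical values): with a single pole, the
# function reduces to mul_coeffs[0] - (coeff_b*x - coeff_a) * (x - sub_coeffs[0]),
# which we can check directly.
x_example = jnp.array([0.5, 1.5])
val = eval_polynomial(x_example, coeff_a=1.0, coeff_b=2.0,
                      mul_coeffs=jnp.array([3.0]),
                      sub_coeffs=jnp.array([4.0]))
assert jnp.allclose(val, 3.0 - (2.0 * x_example - 1.0) * (x_example - 4.0))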
def upsample_posterior(x, b, log_diag_cov, repeats):
    """ Posterior of N(x|Az + b, Sigma) where A is an upsample matrix """
    assert x.shape == b.shape
    assert x.shape == log_diag_cov.shape
    assert x.ndim == 3
    xmb = x - b
    one_over_diag_cov = jnp.exp(-log_diag_cov)

    # Compute the diagonal of the riemannian metric.
    # This is the diagonal of A^T Sigma^{-1} A
    hr, wr, cr = repeats
    assert cr == 1  # Haven't tested cr != 1
    Hx, Wx, C = x.shape
    H, W = Hx // hr, Wx // wr
    rm_diag = one_over_diag_cov.reshape((H, hr, W, wr, C)).transpose(
        (0, 2, 4, 1, 3)).reshape((H, W, C, hr * wr)).sum(axis=-1)

    # Compute the mean of z
    z_mean = upsample_pseudo_inverse(xmb * one_over_diag_cov,
                                     (2, 2, 1)) / rm_diag * (hr * wr)
    x_proj = upsample(repeats, z_mean) * one_over_diag_cov

    dim_x = jnp.prod(x.shape)
    dim_z = jnp.prod(z_mean.shape)

    # Compute the manifold error term
    log_hx = -0.5 * jnp.sum(xmb * (xmb * one_over_diag_cov - x_proj))
    log_hx -= 0.5 * jnp.sum(jnp.log(rm_diag))
    log_hx -= 0.5 * log_diag_cov.sum()
    log_hx -= 0.5 * (dim_x - dim_z) * jnp.log(2 * jnp.pi)

    # return z_mean, log_hx, rm_diag, x_proj
    return z_mean, log_hx, rm_diag
def gaussian_potential(x: jnp.ndarray,
                       mean: Union[float, jnp.ndarray] = 0.,
                       prec: Union[float, jnp.ndarray] = None,
                       sqrt_prec: Union[float, jnp.ndarray] = None,
                       det_prec: float = None) -> Union[float, jnp.ndarray]:
    # sqrt_prec such that prec = sqrt_prec @ sqrt_prec.T
    d = x.shape[-1]

    if prec is None and sqrt_prec is None:
        prec = 1.
    if isinstance(prec, float):
        prec = jnp.ones(d) * prec
    if isinstance(sqrt_prec, float):
        sqrt_prec = jnp.ones(d) * sqrt_prec

    if det_prec is None:
        if prec is not None and prec.ndim < 2:
            det_prec = jnp.prod(prec)
        elif sqrt_prec is not None and sqrt_prec.ndim < 2:
            det_prec = jnp.prod(sqrt_prec)**2

    if det_prec is None:
        # Full precision matrix given but no det - computing without norm constant.
        neg_log_z = 0
        warn(
            'gaussian_potential queried with non-diagonal prec (or sqrt-prec) but no det_prec given'
            ' -> executing without normalising constant term')
    else:
        neg_log_z = (d * jnp.log(2 * jnp.pi) - jnp.log(det_prec)) / 2

    if x.ndim == 1 and sqrt_prec is None:
        # Single value (not vectorised)
        if prec is None:
            out_val = _mv_gaussian_potential_diag(x, mean, 1.)
        elif prec.ndim < 2:
            out_val = _mv_gaussian_potential_diag(x, mean, prec)
        else:
            out_val = _mv_gaussian_potential(x, mean, prec)
    else:
        # Multiple values (vectorised)
        if prec is not None and sqrt_prec is None:
            if prec.ndim < 2:
                sqrt_prec = jnp.sqrt(prec)
            else:
                sqrt_prec = jnp.linalg.cholesky(prec)
                warn(
                    'vectorised gaussian_potential queried with prec rather than sqrt_prec'
                    ' -> executing using Cholesky decomp')
        if sqrt_prec is None:
            out_val = _mv_gaussian_potential_diag(x, mean, 1.)
        elif sqrt_prec.ndim < 2:
            out_val = _mv_gaussian_potential_diag(x, mean, sqrt_prec**2)
        else:
            out_val = _vectorised_gaussian_potential(x, mean, sqrt_prec)

    return out_val + neg_log_z
def reduce(
    total: jnp.ndarray,
    count: tp.Optional[jnp.ndarray],
    values: jnp.ndarray,
    reduction: Reduction,
    sample_weight: tp.Optional[np.ndarray],
    dtype: jnp.dtype,
) -> tp.Tuple[jnp.ndarray, jnp.ndarray, tp.Optional[jnp.ndarray]]:

    if sample_weight is not None:
        sample_weight = sample_weight.astype(dtype)

        # Update dimensions of weights to match with values if possible.
        # values, _, sample_weight = tf_losses_utils.squeeze_or_expand_dimensions(
        #     values, sample_weight=sample_weight
        # )
        try:
            # Broadcast weights if possible.
            sample_weight = jnp.broadcast_to(sample_weight, values.shape)
        except ValueError:
            # Reduce values to same ndim as weight array.
            ndim = values.ndim
            weight_ndim = sample_weight.ndim
            if reduction == Reduction.SUM:
                values = jnp.sum(values, axis=list(range(weight_ndim, ndim)))
            else:
                values = jnp.mean(values, axis=list(range(weight_ndim, ndim)))

        values = values * sample_weight

    value_sum = jnp.sum(values)
    total += value_sum

    # Exit early if the reduction doesn't have a denominator.
    if reduction == Reduction.SUM:
        num_values = None
    # Update `count` for reductions that require a denominator.
    elif reduction == Reduction.SUM_OVER_BATCH_SIZE:
        num_values = jnp.prod(values.shape).astype(dtype)
    else:
        if sample_weight is None:
            num_values = jnp.prod(jnp.array(values.shape)).astype(dtype)
        else:
            num_values = jnp.sum(sample_weight)

    if count is not None and num_values is not None:
        count += num_values

    if reduction == Reduction.SUM:
        value = total
    else:
        value = total / count

    return value, total, count
def _hessianopt(x, f):
    _, hvp = jax.linearize(jax.grad(f), x)
    hvp = jax.jit(hvp)
    n = np.prod(x.shape)
    # `vsize` (the vmap batch size) is assumed to come from the enclosing
    # scope; the basis is split into chunks of at most `vsize` rows to bound
    # memory use.
    idxs = np.arange(vsize, n, vsize)
    basis = np.eye(np.prod(x.shape)).reshape(-1, *x.shape)
    splitbasis = np.split(basis, idxs)
    vhvp = jax.vmap(hvp)
    vhvp = jax.jit(vhvp)
    return np.concatenate([vhvp(b) for b in splitbasis]).reshape(x.shape + x.shape)
def prob_fn(sample: Array, mu: Array, sigma: Array, action_spec):
    # Support scalar and vector `sigma`. If vector, mu.shape == sigma.shape.
    mu = mu_activation(mu)
    sigma = sigma_activation(sigma)
    # Compute pdf for multivariate gaussian.
    d = mu.shape[-1]
    det = jnp.prod(sigma**2, axis=-1)
    z = ((2 * jnp.pi)**(0.5 * d)) * (det**0.5)
    exp = jnp.exp(-0.5 * jnp.sum(
        ((mu - inv_transform(sample, action_spec)) / sigma)**2, axis=-1))
    det_jacobian = jnp.prod(jnp.clip(1 - sample**2, 0., 1.) + 1e-6)
    return exp / (z * det_jacobian)
def mvn_kl(mu_0, sigma_0, mu_1, sigma_1):
    # slogdet returns (sign, log|det|); their product is the signed log-det.
    logdet_sigma_1 = jnp.prod(jnp.array(jnp.linalg.slogdet(sigma_1)))
    logdet_sigma_0 = jnp.prod(jnp.array(jnp.linalg.slogdet(sigma_0)))
    term_1 = 0.5 * (logdet_sigma_1 - logdet_sigma_0)

    # I wonder if there's a more efficient way?
    mu_outer = jnp.outer(mu_0 - mu_1, mu_0 - mu_1)
    inside_term = mu_outer + sigma_0 - sigma_1
    solved = jnp.linalg.solve(sigma_1, inside_term)
    term_2 = 0.5 * jnp.trace(solved)

    return term_1 + term_2
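# A quick sanity sketch (hypothetical values): the KL between identical
# Gaussians should vanish, and for isotropic covariances s0*I, s1*I the
# closed form 0.5*(d*log(s1/s0) + d*s0/s1 + ||mu_0 - mu_1||^2/s1 - d)
# should match.
mu_a, mu_b = jnp.array([0.0, 1.0]), jnp.array([1.0, -1.0])
s0, s1 = 0.5, 2.0
assert jnp.allclose(mvn_kl(mu_a, s0 * jnp.eye(2), mu_a, s0 * jnp.eye(2)),
                    0.0, atol=1e-6)
kl = mvn_kl(mu_a, s0 * jnp.eye(2), mu_b, s1 * jnp.eye(2))
closed_form = 0.5 * (2 * jnp.log(s1 / s0) + 2 * s0 / s1
                     + jnp.sum((mu_a - mu_b)**2) / s1 - 2)
assert jnp.allclose(kl, closed_form)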
def testPermutationArray(self, dtype, shape):
    key = random.PRNGKey(0)
    x = np.arange(np.prod(shape)).reshape(shape).astype(dtype)
    rand = lambda key: random.permutation(key, x)
    crand = api.jit(rand)

    perm1 = rand(key)
    perm2 = crand(key)

    self.assertAllClose(perm1, perm2, check_dtypes=True)
    self.assertFalse(onp.all(perm1 == x))  # seems unlikely!
    self.assertAllClose(onp.sort(perm1.ravel()), x.ravel(), check_dtypes=False)
    # Check that the input array was not mutated.
    self.assertArraysAllClose(
        x, np.arange(np.prod(shape)).reshape(shape).astype(dtype),
        check_dtypes=True)
def _hessianopt(x, f):
    _, hvp = jax.linearize(jax.grad(f), x)
    hvp = jax.jit(hvp)
    vhvp = jax.vmap(hvp)
    vhvp = jax.jit(vhvp)
    basis = np.eye(np.prod(x.shape)).reshape(-1, *x.shape)
    return vhvp(basis).reshape(x.shape + x.shape)
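# A minimal usage sketch (hypothetical function; assumes `np` aliases
# jax.numpy as the snippet's usage suggests): the dense Hessian of the
# quadratic f(x) = x @ A @ x should recover A + A.T.
A_mat = np.array([[1.0, 2.0], [3.0, 4.0]])
f_quad = lambda x: x @ A_mat @ x
H = _hessianopt(np.zeros(2), f_quad)
assert np.allclose(H, A_mat + A_mat.T)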
def __call__(self, inputs, context=None, reverse=False):
    axes = tuple(i for i in range(len(inputs.shape) - 1))

    def dd_mean_initializer(key, shape):
        """Data-dependent init for mu"""
        nonlocal inputs
        x_mean = np.mean(inputs, axis=axes, keepdims=True)
        return -x_mean

    def dd_stddev_initializer(key, shape):
        """Data-dependent init for sigma"""
        nonlocal inputs
        x_var = np.mean(inputs**2, axis=axes, keepdims=True)
        var = self.scale / (np.sqrt(x_var) + self.eps)
        return var

    shape = (1,) * len(axes) + (inputs.shape[-1],)
    mu = self.param("actnorm_mean", dd_mean_initializer, shape)
    sigma = self.param("actnorm_stddev", dd_stddev_initializer, shape)

    logsigma = np.log(np.abs(sigma))
    log_det_jacobian = np.prod(np.array(inputs.shape[1:-1])) * np.sum(logsigma)

    if reverse:
        outputs = inputs / (sigma + self.eps) - mu
        log_det_jacobian = -log_det_jacobian
    else:
        outputs = sigma * (inputs + mu)
        log_det_jacobian = log_det_jacobian

    return outputs, log_det_jacobian
def testThreadsafeIndexing(self):
    # NOTE(skye): I picked these values to be big enough to cause interesting
    # execution overlap, but small enough to not use too much memory. YMMV.
    shape = (8, 8000, 1000)

    if jax.device_count() < shape[0]:
        raise SkipTest(f"requires {shape[0]} devices")

    x = np.arange(np.prod(shape)).reshape(shape)
    sharded_x = pmap(lambda x: x)(x)

    num_threads = 10
    futures = []
    expected = []
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        for i in range(num_threads):
            idx = i % shape[0]
            # Mix together different kinds of indices.
            if i % 2 == 0:
                idx = slice(idx, idx + 1)
            # Use the "kwarg trick" to work around late-binding closures. See
            # https://docs.python-guide.org/writing/gotchas/#late-binding-closures.
            futures.append(executor.submit(
                lambda idx=idx: [sharded_x[idx] for _ in range(10)][0]))
            expected.append(x[idx])
        actual = [f.result() for f in futures]

    self.assertAllClose(actual, expected, check_dtypes=False)
def maxandargmax(x, axis=axis):
    # Note: the `axis` default is captured from the enclosing scope at
    # definition time (this function is presumably built inside a factory).
    if axis is None:
        axes = tuple(range(x.ndim))
    else:
        axes = tuple(int(ax) for ax in axis)

    max_res = jnp.max(x, axis)

    # NumPy does not support multiple axes for argmax; this is a
    # work-around
    keep_axes = jnp.array(
        [i for i in range(x.ndim) if i not in axes], dtype="int64"
    )
    # Not-reduced axes in front
    transposed_x = jnp.transpose(
        x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
    )
    kept_shape = transposed_x.shape[: len(keep_axes)]
    reduced_shape = transposed_x.shape[len(keep_axes) :]

    # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
    # Otherwise reshape would complain citing float arg
    new_shape = kept_shape + (
        jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
    )
    reshaped_x = transposed_x.reshape(new_shape)

    max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")

    return max_res, max_idx_res
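# A minimal usage sketch (hypothetical input; assumes the enclosing-scope
# `axis` that supplies the default is in scope, and calls with an explicit
# axis): reducing over the trailing axes should agree with flatten-then-argmax.
x_demo = jnp.arange(24.0).reshape(2, 3, 4)
mx, midx = maxandargmax(x_demo, axis=(1, 2))
assert jnp.array_equal(mx, x_demo.reshape(2, -1).max(axis=-1))
assert jnp.array_equal(midx, x_demo.reshape(2, -1).argmax(axis=-1))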
def generate_sample_grid(theta_mean, theta_std, n):
    """
    Create a meshgrid of n ** n_dim samples, tiling
    [theta_mean[i] - 5 * theta_std[i], theta_mean[i] + 5 * theta_std[i]]
    into n portions. Also returns the volume element.

    Parameters
    ----------
    theta_mean, theta_std : ndarray (n_dim)

    Returns
    -------
    theta_samples : ndarray (nobj, n_dim)
    vol_element : scalar
        Volume element
    """
    n_components = theta_mean.size
    xs = [
        np.linspace(
            theta_mean[i] - 5 * theta_std[i],
            theta_mean[i] + 5 * theta_std[i],
            n,
        )
        for i in range(n_components)
    ]
    mxs = np.meshgrid(*xs)
    orshape = mxs[0].shape
    mxsf = np.vstack([i.ravel() for i in mxs]).T
    dxs = np.vstack([np.diff(xs[i])[i] for i in range(n_components)])
    vol_element = np.prod(dxs)
    theta_samples = np.vstack(mxsf)
    return theta_samples, vol_element
def init_state(self, a_shape, rng):
    # Uses random as a hack to support vmap;
    # we should find a non-hack approach to initializing the state.
    dim_a = jnp.prod(a_shape)  # np.int32
    a_opt = 0.0 * jax.random.uniform(
        rng, shape=(self.n_steps, dim_a))  # [n_steps, dim_a]
    return a_opt
def tmrca_sf(t: np.ndarray, y: np.ndarray, n: int) -> np.ndarray:
    """The survival function of the TMRCA at each time point

    Args:
        t: time grid (including zero and infinity)
        y: effective population size in each epoch
        n: number of sampled haplotypes
    """
    # epoch durations
    s = np.diff(t)
    logu = -s / y
    logu = np.concatenate((np.array([0]), logu))
    # the A_2j are the product of this matrix
    # NOTE: using letter "l" as a variable name to match text
    l = onp.arange(2, n + 1)[:, onp.newaxis]  # noqa: E741
    with onp.errstate(divide='ignore'):
        A2_terms = l * (l - 1) / (l * (l - 1) - l.T * (l.T - 1))
    onp.fill_diagonal(A2_terms, 1)
    A2 = np.prod(A2_terms, axis=0)
    binom_vec = l * (l - 1) / 2
    result = np.zeros(len(t))
    result = index_update(
        result, index[:-1],
        np.squeeze(
            A2[np.newaxis, :]
            @ np.exp(np.cumsum(logu[np.newaxis, :-1], axis=1)) ** binom_vec))
    assert np.all(np.isfinite(result))
    return result
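# A minimal sanity sketch (hypothetical grid; assumes `np` is jax.numpy and
# that the legacy jax.ops index_update/index imports this function relies on
# are available): with n = 2 haplotypes and constant population size 1, the
# TMRCA is Exponential(1), so the survival function is exp(-t).
t_grid = np.array([0.0, 0.5, 1.0, 2.0, np.inf])
y_grid = np.ones(len(t_grid) - 1)
sf = tmrca_sf(t_grid, y_grid, n=2)
assert np.allclose(sf[:-1], np.exp(-t_grid[:-1]))
assert sf[-1] == 0.0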
def _triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None):  # pylint: disable=redefined-outer-name
    """Scipy solve does not broadcast, so we must do so explicitly."""
    del name
    if JAX_MODE:  # But JAX uses XLA, which can do a batched solve.
        matrix = matrix + np.zeros(rhs.shape[:-2] + (1, 1), dtype=matrix.dtype)
        rhs = rhs + np.zeros(matrix.shape[:-2] + (1, 1), dtype=rhs.dtype)
        return scipy_linalg.solve_triangular(
            matrix, rhs, lower=lower, trans='C' if adjoint else 'N')
    try:
        bcast = onp.broadcast(matrix[..., :1], rhs)
    except ValueError as e:
        raise ValueError(
            'Error with inputs shaped `matrix`={}, rhs={}:\n{}'.format(
                matrix.shape, rhs.shape, str(e)))
    dim = matrix.shape[-1]
    matrix = onp.broadcast_to(matrix, bcast.shape[:-1] + (dim,))
    rhs = onp.broadcast_to(rhs, bcast.shape)
    nbatch = int(np.prod(matrix.shape[:-2]))
    flat_mat = matrix.reshape(nbatch, dim, dim)
    flat_rhs = rhs.reshape(nbatch, dim, rhs.shape[-1])
    result = np.empty(flat_rhs.shape)
    if np.size(result):
        # ValueError: On entry to STRTRS parameter number 7 had an illegal value.
        for i, (mat, rh) in enumerate(zip(flat_mat, flat_rhs)):
            result[i] = scipy_linalg.solve_triangular(
                mat, rh, lower=lower, trans='C' if adjoint else 'N')
    return result.reshape(*rhs.shape)
def get_bins_and_bincounts(samples, normalized=False):
    """Take in samples, create a common set of bins, and compute the counts
    count(x in bin) for each bin and each sample set x.

    Parameters
    ------------
    samples : np.array of shape (n,) or shape (k, n).
        - If shape (n,): interpreted as a set of n scalar-valued samples.
        - If shape (k, n): interpreted as k sets of n scalar-valued samples.

    Returns
    --------
    probabilities : np.array
        Bin counts (or normalized densities) per sample set.
    bins : np.array
        The shared bin edges.
    """
    nr_samples = np.prod(samples.shape)
    nr_bins = np.log2(nr_samples)
    nr_bins = int(max(nr_bins, 5))
    lims = [np.min(samples), np.max(samples)]
    bins = np.linspace(*lims, num=nr_bins)

    if samples.ndim == 2:
        out = np.asarray([
            np.histogram(x, bins=bins, density=normalized)[0] for x in samples
        ])
        return out, bins
    elif samples.ndim == 1:
        return np.histogram(samples, bins=bins, density=normalized)[0], bins
    else:
        raise ValueError(
            f"Input must have shape (n,) or shape (k, n). "
            f"Instead received shape {samples.shape}")
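# A minimal usage sketch (hypothetical data): every sample lands in some bin,
# and the edges follow the np.histogram convention of len(counts) + 1.
demo_samples = np.linspace(-1.0, 1.0, 64)
counts, bin_edges = get_bins_and_bincounts(demo_samples)
assert counts.sum() == demo_samples.size
assert len(bin_edges) == len(counts) + 1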
def _get_inputs(key, is_conv, same_inputs, input_shape, fn=np.cos):
    key, split = random.split(key)
    shape = input_shape if is_conv else (input_shape[0], np.prod(input_shape[1:]))
    x1 = fn(random.normal(key, shape))
    x2 = None if same_inputs else 2 * fn(random.normal(split, shape))
    return x1, x2
def __init__(self, input_shape):
    super(MLPDynamics, self).__init__()
    self.input_shape = input_shape
    self.dim = jnp.prod(input_shape[1:])
    self.hidden_dim = 100
    self.lin1 = hk.Linear(self.hidden_dim)
    self.lin2 = hk.Linear(self.dim)
def update(self, mpc_state, env, env_state, rng,
           reward_fn=None, reward_params=None, reward_rng=None):
    # mpc_state: ([n_steps, dim_a], [n_steps, dim_a, dim_a])
    # env: {.step(s, a), .reward(s)}
    # env_state: [env_shape] np.float32
    # rng: rng key for mpc sampling
    # reward_fn: reward_fn(env, s, params, rng)
    # reward_params: params for reward function
    # reward_rng: rng key for reward function stochasticity, e.g. dropout

    dim_a = jnp.prod(env.a_shape)  # np.int32
    a_opt, a_cov = mpc_state
    a_opt = jnp.concatenate([a_opt[1:, :],
                             jnp.expand_dims(jnp.zeros((dim_a,)),
                                             axis=0)])  # [n_steps, dim_a]
    if self.adaptive_covariance:
        a_cov = jnp.concatenate([a_cov[1:, :],
                                 jnp.expand_dims((self.a_std**2)*jnp.eye(dim_a),
                                                 axis=0)])

    def iteration_step(input_, _):
        a_opt, a_cov, rng = input_
        rng_da, rng = jax.random.split(rng)
        if self.adaptive_covariance:
            da = jax.vmap(jax.random.multivariate_normal, (0, 0, 0, None), 1)(
                jax.random.split(rng_da, self.n_steps),  # [n_steps], rngs
                jnp.zeros((self.n_steps, dim_a)),  # [n_steps, dim_a] mean
                a_cov,  # [n_steps, dim_a, dim_a] covariance
                (self.n_samples,),
            )  # [n_samples, n_steps, dim_a]
        else:
            da = self.a_std*jax.random.normal(
                rng_da, shape=(self.n_samples, self.n_steps, dim_a)
            )  # [n_samples, n_steps, dim_a]

        # a: [n_samples, n_steps, dim_a]
        a = jnp.clip(jnp.expand_dims(a_opt, axis=0) + da, -1.0, 1.0)
        r = jax.vmap(self.rollout, in_axes=(0, None, None, None, None, None))(
            a, env, env_state, reward_fn, reward_params, reward_rng
        )  # [n_samples, n_steps]
        R = jax.vmap(self.returns)(r)  # [n_samples, n_steps], pylint: disable=invalid-name
        w = jax.vmap(self.weights, 1, 1)(R)  # [n_samples, n_steps]
        da_opt = jax.vmap(jnp.average, (1, None, 1))(da, 0, w)  # [n_steps, dim_a]
        a_opt = jnp.clip(a_opt + da_opt, -1.0, 1.0)  # [n_steps, dim_a]
        if self.adaptive_covariance:
            a_cov = jax.vmap(jax.vmap(jnp.outer))(
                da, da
            )  # [n_samples, n_steps, dim_a, dim_a]
            a_cov = jax.vmap(jnp.average, (1, None, 1))(
                a_cov, 0, w
            )  # a_cov: [n_steps, dim_a, dim_a]
            # prevent loss of rank when one sample is heavily weighted
            a_cov = a_cov + jnp.eye(dim_a)*0.00001
        return (a_opt, a_cov, rng), None

    if not self.scan:
        for _ in range(self.n_iterations):
            (a_opt, a_cov, rng), _ = iteration_step((a_opt, a_cov, rng), None)
    else:
        (a_opt, a_cov, rng), _ = jax.lax.scan(
            iteration_step, (a_opt, a_cov, rng), None, length=self.n_iterations
        )

    return (a_opt, a_cov)
def test_shuffled_neuron_no_input_ablation_mask_sparsity_full(self):
    """Tests shuffled mask generation, for 100% sparsity."""
    mask = masked.shuffled_neuron_no_input_ablation_mask(
        self._masked_model, self._rng, 1.0)

    with self.subTest(name='shuffled_full_mask'):
        self.assertIn('MaskedModule_0', mask)

    with self.subTest(name='shuffled_full_mask_values'):
        self.assertEqual(
            jnp.count_nonzero(mask['MaskedModule_0']['kernel']),
            jnp.prod(jnp.array(self._input_dimensions)))

    with self.subTest(name='shuffled_full_no_input_ablation'):
        # Check no row (neurons are columns) is completely ablated.
        self.assertTrue((jnp.count_nonzero(
            mask['MaskedModule_0']['kernel'], axis=0) != 0).all())

    with self.subTest(name='shuffled_full_mask_not_masked_values'):
        self.assertIsNone(mask['MaskedModule_0']['bias'])

    masked_output = self._masked_model(self._input, mask=mask)

    with self.subTest(name='shuffled_full_mask_dense_shape'):
        self.assertSequenceEqual(masked_output.shape,
                                 self._unmasked_output.shape)
def compute_weighted_cross_entropy(logits, targets, weights=None):
    """Compute weighted cross entropy and entropy for log probs and targets.

    Args:
      logits: `[batch, length, num_classes]` float array.
      targets: categorical targets `[batch, length]` int array.
      weights: None or array of shape [batch, length, 1]

    Returns:
      Tuple of scalar loss and batch normalizing factor.
    """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s targets' %
            (str(logits.shape), str(targets.shape)))
    if logits.shape[1] != targets.shape[1]:
        # Truncate logits.
        logits = logits[:, :targets.shape[1]]

    onehot_targets = common_utils.onehot(targets, logits.shape[-1])
    loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
    normalizing_factor = jnp.prod(jnp.asarray(targets.shape))
    if weights is not None:
        loss = loss * weights
        normalizing_factor = weights.sum()
    return loss.sum(), normalizing_factor
def compute_weighted_accuracy(logits, targets, weights=None):
    """Compute weighted accuracy for log probs and targets.

    Args:
      logits: `[batch, length, num_classes]` float array.
      targets: categorical targets `[batch, length]` int array.
      weights: None or array of shape [batch, length, 1]

    Returns:
      Tuple of scalar accuracy and batch normalizing factor.
    """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s targets' %
            (str(logits.shape), str(targets.shape)))
    if logits.shape[1] != targets.shape[1]:
        # Truncate logits.
        logits = logits[:, :targets.shape[1]]

    acc = jnp.equal(jnp.argmax(logits, axis=-1), targets)
    normalizing_factor = jnp.prod(jnp.asarray(targets.shape))
    if weights is not None:
        acc = acc * weights
        normalizing_factor = weights.sum()
    return acc.sum(), normalizing_factor
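# A minimal usage sketch (hypothetical shapes): two correct predictions give
# an accuracy sum of 2 over a normalizing factor of batch * length = 2.
demo_logits = jnp.array([[[2.0, 0.0], [0.0, 3.0]]])  # [batch=1, length=2, classes=2]
demo_targets = jnp.array([[0, 1]])
acc_sum, norm = compute_weighted_accuracy(demo_logits, demo_targets)
assert acc_sum == 2 and norm == 2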
def loss_fn(variables):
    rays = batch["rays"]
    ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
    if len(ret) not in (1, 2):
        raise ValueError(
            "ret should contain either 1 set of output (coarse only), or 2 sets "
            "of output (coarse as ret[0] and fine as ret[1]).")
    # The main prediction is always at the end of the ret list.
    rgb, unused_disp, unused_acc = ret[-1]
    loss = ((rgb - batch["pixels"][Ellipsis, :3])**2).mean()
    psnr = utils.compute_psnr(loss)
    if len(ret) > 1:
        # If there are both coarse and fine predictions, we compute the loss for
        # the coarse prediction (ret[0]) as well.
        rgb_c, unused_disp_c, unused_acc_c = ret[0]
        loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3])**2).mean()
        psnr_c = utils.compute_psnr(loss_c)
    else:
        loss_c = 0.
        psnr_c = 0.

    def tree_sum_fn(fn):
        return jax.tree_util.tree_reduce(
            lambda x, y: x + fn(y), variables, initializer=0)

    weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z**2)) /
                 tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))

    stats = utils.Stats(loss=loss, psnr=psnr, loss_c=loss_c,
                        psnr_c=psnr_c, weight_l2=weight_l2)
    return loss + loss_c + FLAGS.weight_decay_mult * weight_l2, stats
def reduce_prod(x, axis=None, keepdims=False):
    if axis is None:
        num_dims = len(x.shape)
        axis = tuple(range(num_dims))
    elif isinstance(axis, list):
        axis = tuple(axis)
    return _jnp.prod(x, axis=axis, keepdims=keepdims)
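# A minimal usage sketch (hypothetical values; assumes `_jnp` aliases
# jax.numpy as the body suggests):
x_demo = _jnp.arange(1.0, 7.0).reshape(2, 3)
assert reduce_prod(x_demo) == 720.0                     # product over all axes
assert reduce_prod(x_demo, axis=0).shape == (3,)        # reduce first axis only
assert reduce_prod(x_demo, axis=[0, 1], keepdims=True).shape == (1, 1)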