def update(self, x, y_true, params, averager=None):
    """Run one forward/backward pass on a single example and update the weights.

    The error of the output layer is propagated to the hidden layer through
    ``abs(self.V2) * sign(self.W2.T)`` rather than through ``self.W2.T`` itself.

    Args:
        x: Input vector passed to ``self.forward``.
        y_true: Target vector subtracted from the network output.
        params: Mapping that must contain the learning rate under ``'lr'``;
            it is also forwarded to ``self.forward``.
        averager: Optional object with an ``add(name, value)`` method. When
            given, the angle (in degrees) between the backward matrix and
            ``self.W2.T`` is logged under ``'backward_angle'``.

    Returns:
        The network output ``h2`` computed before the parameter update.
    """
    # Run forward pass.
    z1, h1, z2, h2 = self.forward(x, params, return_activations=True)

    # Compute errors for each layer (= gradient of cost w.r.t layer input).
    e2 = h2 - y_true  # gradient through cross entropy loss
    backward_weights = np.abs(self.V2) * np.sign(self.W2.T)
    e1 = d_sigmoid(z1) * (e2 @ backward_weights)  # gradient backpropagation

    # Using these errors, compute gradients of cost w.r.t. parameters.
    grad_b1 = e1
    grad_b2 = e2
    grad_W1 = np.outer(x, e1)  # np.outer creates a matrix from two vectors
    grad_W2 = np.outer(h1, e2)

    # Update parameters.
    lr = params['lr']
    self.b1 -= lr * grad_b1
    self.b2 -= lr * grad_b2
    self.W1 -= lr * grad_W1
    self.W2 -= lr * grad_W2

    # Logging is optional: the default ``averager=None`` must not crash.
    if averager is not None:
        # Recomputed on purpose: the angle is measured against the updated W2.
        updated_backward = np.abs(self.V2) * np.sign(self.W2.T)
        averager.add(
            'backward_angle',
            np.rad2deg(
                utils.angle_between(updated_backward.flatten(),
                                    self.W2.T.flatten())))

    return h2
def testDerivativeIsMonotonicWrtX(self):
    """Check that the loss increases monotonically with |x|.

    The derivative w.r.t. x must carry the same sign as x wherever alpha is
    finite and the derivative is not negligibly small.
    """
    _, _, x, alpha, _, d_x, _, _ = self._precompute_lossfun_inputs()

    # Zero out non-finite derivatives; this only suppresses a warning below.
    finite_d_x = jnp.isfinite(d_x)
    d_x = jnp.where(finite_d_x, d_x, jnp.zeros_like(d_x))

    tolerance = 300. * jnp.finfo(jnp.float32).eps
    significant = jnp.abs(d_x) > tolerance
    mask = jnp.isfinite(alpha) & significant

    chex.assert_tree_all_close(jnp.sign(d_x[mask]), jnp.sign(x[mask]))
def update_embedding_dbd(embedding, grad, vel, gain, lr, iter_num):
    """Update the embedding using delta-bar-delta.

    Args:
        embedding: Current embedding; the new velocity is added to it.
        grad: Gradient of the objective w.r.t. the embedding.
        vel: Velocity from the previous step.
        gain: Per-element gains from the previous step.
        lr: Learning rate.
        iter_num: Iteration counter; selects the momentum coefficient.

    Returns:
        Tuple ``(embedding, gain, vel)`` with the updated values.
    """
    momentum = jnp.where(
        iter_num > _SWITCH_ITER, _FINAL_MOMENTUM, _INIT_MOMENTUM)

    # Gains grow additively where velocity and gradient disagree in sign and
    # shrink multiplicatively (bounded below) everywhere else.
    signs_differ = jnp.sign(vel) != jnp.sign(grad)
    increased = gain + _INCREASE_GAIN
    damped = jnp.maximum(gain * _DAMP_GAIN, _MIN_GAIN)
    gain = jnp.where(signs_differ, increased, damped)

    vel = momentum * vel - lr * gain * grad
    embedding += vel
    return embedding, gain, vel
def terngrad_quantize(v: jnp.ndarray, rng: PRNGKey) -> jnp.ndarray:
    """Terngrad algorithm https://arxiv.org/abs/1705.07878.

    Entries whose magnitude exceeds 2.5 standard deviations are first clipped
    to that limit; the magnitudes are then stochastically quantized between
    0 and the largest magnitude, and the original signs are restored.

    Args:
      v: vector to be quantized.
      rng: jax random key.

    Returns:
      Quantized array.
    """
    limit = 2.5 * jnp.std(v)
    clipped = jnp.where(jnp.abs(v) > limit, limit * jnp.sign(v), v)

    magnitude = jnp.abs(clipped)
    quantized_magnitude = binary_stochastic_quantize(
        magnitude, rng, 0., jnp.amax(magnitude))
    return quantized_magnitude * jnp.sign(clipped)
def cost_fn(x, y, power):
    """Return the transport cost |x - y|^p and its derivative w.r.t. (x - y).

    ``x`` and ``y`` are indexed as ``x[:, :, newaxis]`` and
    ``y[:, newaxis, :]``, so both outputs hold all pairwise differences along
    the last two axes.

    Args:
        x: Array with two axes (leading axis shared with ``y``).
        y: Array with two axes (leading axis shared with ``x``).
        power: Exponent ``p`` of the cost.

    Returns:
        Tuple ``(cost, derivative)``.
    """
    diff = x[:, :, np.newaxis] - y[:, np.newaxis, :]

    # Closed forms for the two common exponents avoid the generic power call.
    if power == 1.0:
        return np.abs(diff), np.sign(diff)
    if power == 2.0:
        return diff**2.0, 2.0 * diff

    magnitude = np.abs(diff)
    cost = magnitude**power
    derivative = power * np.sign(diff) * magnitude**(power - 1.0)
    return cost, derivative
def get_psd(self, omega):
    """Return the wrapped term's PSD multiplied by a squared sinc factor.

    The factor is ``(sin(a) / a) ** 2`` with ``a = 0.5 * self.delta * omega``.

    Args:
        omega: Scalar or array of angular frequencies.

    Returns:
        ``self.term.get_psd(omega) * sinc(a) ** 2`` as an array with at least
        one dimension.
    """
    omega = np.atleast_1d(omega)
    psd0 = self.term.get_psd(omega)
    arg = 0.5 * self.delta * omega
    # np.sinc(t) is sin(pi t) / (pi t) with the limit value 1 at t == 0, so
    # omega == 0 yields 1 instead of the NaN produced by a literal 0 / 0.
    sinc = np.sinc(arg / np.pi)
    return psd0 * sinc**2
def logdamp(move: Array) -> Array:
    """Logarithmically compress entries of ``move`` whose magnitude exceeds 1.

    Entries with ``|move| <= 1`` pass through unchanged; larger entries become
    ``sign(move) * log(1 + 1.72 * |move|)``.
    """
    magnitude = jnp.abs(move)
    compressed = jnp.log(1 + magnitude * 1.72) * jnp.sign(move)
    return jnp.where(magnitude > 1, compressed, move)
def make_noise_sqrt(rng, shape):
    """Sample signed-square-root noise with gradients blocked.

    Draws from a normal distribution truncated to [-2, 2] and returns
    ``sign(n) * sqrt(|n|)`` wrapped in ``stop_gradient``.

    Args:
        rng: JAX random key.
        shape: Shape of the returned array.
    """
    samples = jax.random.truncated_normal(
        rng, lower=-2., upper=2., shape=shape)
    signed_root = jnp.sign(samples) * jnp.sqrt(jnp.abs(samples))
    return jax.lax.stop_gradient(signed_root)
def update_opt(_, grads, state):
    """Apply one step of the RNN-based update rule.

    Args:
        _: Unused first argument.
        grads: Gradients; flattened into a column vector before use.
        state: Tuple ``(x, h)`` of current iterate and RNN hidden state.

    Returns:
        Tuple ``(x_next, h_next)``.

    Raises:
        ValueError: If ``input_scale`` is neither a number, ``'raw'`` nor
            ``'log1p'``.
    """
    x, h = state
    grad_column = jnp.reshape(grads, (-1, 1))

    if isinstance(input_scale, numbers.Number):
        # Inputs are scaled by a constant factor.
        inputs = input_scale * grad_column
    elif input_scale == 'raw':
        # Inputs are raw (unmodified) gradients.
        inputs = grad_column
    elif input_scale == 'log1p':
        # Inputs are the log-scale and sign of the gradient, side by side.
        log_magnitude = jnp.log1p(jnp.abs(grad_column))
        inputs = jnp.hstack((log_magnitude, jnp.sign(grad_column)))
    else:
        raise ValueError(f'Invalid input scale {input_scale}.')

    h_next = cell.batch_apply(rnn_params, inputs, h)
    step = jnp.reshape(readout_apply(readout_params, h_next), x.shape)
    return (x + output_scale * step, h_next)
def lqpos(mps):
    """
    Reshapes the (chiL, d, chiR) MPS tensor into a (chiL, d*chiR) matrix,
    and computes its LQ decomposition, with the phase of L fixed so as to
    have a non-negative main diagonal. A new right-orthogonal
    (chiL, d, chiR) MPS tensor (reshaped from Q) is returned along with L.

    In addition to being phase-adjusted, L is normalized by division with
    its L2 norm.

    PARAMETERS
    ----------
    mps (array-like): The (chiL, d, chiR) MPS tensor.

    RETURNS
    -------
    L, mps_R: A lower-triangular (chiL x chiL) matrix with a non-negative
              main-diagonal, and a right-orthogonal (chiL, d, chiR) MPS
              tensor such that mps = L @ mps_R (up to the normalization
              of L).
    """
    chiL, d, chiR = mps.shape
    mps_mat = jnp.reshape(mps, (chiL, chiR * d))

    # The LQ factors are obtained from the QR factors of the conjugate
    # transpose: A^H = Q' R'  =>  A = R'^H Q'^H = L Q.
    Qdag, Ldag = jnp.linalg.qr(jnp.conj(mps_mat.T))
    Q = jnp.conj(Qdag.T)
    L = jnp.conj(Ldag.T)

    phases = jnp.sign(jnp.diag(L))
    # sign(0) == 0 would wipe out a whole row of Q (and column of L) for
    # rank-deficient inputs; a zero diagonal entry needs no phase fix.
    phases = jnp.where(phases == 0, jnp.ones_like(phases), phases)

    L = L * phases
    L = L / jnp.linalg.norm(L)
    Q = jnp.conj(phases)[:, None] * Q
    mps_R = Q.reshape(mps.shape)
    return (L, mps_R)
def init(rng, shape):
    """Build an array of the given shape from the Q factor of a random matrix.

    All leading dimensions are flattened into rows and the last dimension is
    kept as columns (so the same code works for conv2d kernels). The result
    is scaled by ``stddev`` converted to ``dtype``.

    Args:
        rng: JAX random key.
        shape: Target shape with at least two dimensions.

    Raises:
        ValueError: If ``shape`` has fewer than two dimensions.
    """
    std = lax.convert_element_type(stddev, dtype)
    if len(shape) < 2:
        raise ValueError(
            'The array to initialize must be at least two-dimensional')

    num_cols = shape[-1]
    num_rows = 1
    for extent in shape[:-1]:
        num_rows *= extent

    # QR is taken of the tall orientation; wide targets are transposed back.
    is_wide = num_rows < num_cols
    flat_shape = (num_cols, num_rows) if is_wide else (num_rows, num_cols)

    gaussian = random.normal(rng, flat_shape, dtype=dtype)
    q, r = np.linalg.qr(gaussian)

    # Flip column signs by the sign of R's diagonal to make Q uniform.
    q *= np.sign(np.diag(r))
    if is_wide:
        q = np.transpose(q)
    return std * np.reshape(q, shape)
def interp(x, xp, fp):
    """
    Simple equivalent of np.interp that compute a linear interpolation.

    We are not doing any checks, so make sure your query points are lying
    inside the array.

    TODO: Implement proper interpolation!

    x, xp, fp need to be 1d arrays
    """
    # First we find the nearest neighbour, kept away from both ends so that
    # a neighbour on either side always exists.
    ind = np.argmin((x - xp) ** 2)
    ind = np.clip(ind, 1, len(xp) - 2)
    xi = xp[ind]

    # Figure out if we are on the right or the left of nearest
    s = np.sign(np.clip(x, xp[1], xp[-2]) - xi).astype(np.int64)

    # Integer +/-1 offset to the second node of the segment. A query exactly
    # on a node (s == 0) uses the right-hand neighbour so the slope is never
    # 0 / 0. The offset has to be an integer: a float (as returned by
    # np.copysign) is not a valid array index.
    step = np.where(s < 0, -1, 1)
    neighbour = ind + step

    # Perform linear interpolation
    a = (fp[neighbour] - fp[ind]) / (xp[neighbour] - xp[ind])
    b = fp[ind] - a * xp[ind]
    return a * x + b
def qrpos(mps):
    """
    Reshapes the (chiL, d, chiR) MPS tensor into a (chiL*d, chiR) matrix,
    and computes its QR decomposition, with the phase of R fixed so as to
    have a non-negative main diagonal. A new left-orthogonal
    (chiL, d, chiR) MPS tensor (reshaped from Q) is returned along with R.

    In addition to being phase-adjusted, R is normalized by division with
    its L2 norm.

    PARAMETERS
    ----------
    mps (array-like): The (chiL, d, chiR) MPS tensor.

    RETURNS
    -------
    mps_L, R: A left-orthogonal (chiL, d, chiR) MPS tensor, and an upper
              triangular (chiR x chiR) matrix with a non-negative main
              diagonal such that mps = mps_L @ R (up to the normalization
              of R).
    """
    chiL, d, chiR = mps.shape
    mps_mat = jnp.reshape(mps, (chiL * d, chiR))
    Q, R = jnp.linalg.qr(mps_mat)

    phases = jnp.sign(jnp.diag(R))
    # sign(0) == 0 would wipe out a whole column of Q (and row of R) for
    # rank-deficient inputs; a zero diagonal entry needs no phase fix.
    phases = jnp.where(phases == 0, jnp.ones_like(phases), phases)

    Q = Q * phases
    R = jnp.conj(phases)[:, None] * R
    R = R / jnp.linalg.norm(R)
    mps_L = Q.reshape(mps.shape)
    return (mps_L, R)
def schedule(count):
    """Return the piecewise-constant schedule value at step ``count``.

    Starts from ``init_value`` and, for every ``(threshold, scale)`` pair in
    ``boundaries_and_scales`` (visited in sorted order), multiplies the value
    by ``scale`` once ``count`` has reached ``threshold``.
    """
    v = init_value
    if boundaries_and_scales is not None:
        for threshold, scale in sorted(boundaries_and_scales.items()):
            # 1 while count < threshold, 0 afterwards. The element-wise
            # jnp.maximum is used because the jnp.max reduction does not
            # accept a Python list of operands.
            indicator = jnp.maximum(0., jnp.sign(threshold - count))
            v = v * indicator + (1 - indicator) * scale * v
    return v
def helmholtz(array, k, step=1.0, aspect_ratio=1.0, mask_f=make_mask,
              mask_f_dual=make_mask_dual):
    """Finite difference approx of the helmholtz operator in 2D.

    Args:
        array: 2D input array.
        k: Scalar entering the stencil centre as ``sign(k) * k**2 * step**2``.
        step: Grid spacing.
        aspect_ratio: Forwarded to both mask factories.
        mask_f: Callable ``(size, aspect_ratio) -> mask`` applied to the input
            and to the convolved result.
        mask_f_dual: Callable ``(size, aspect_ratio) -> mask`` selecting the
            part of ``array`` that is added back unchanged.

    Raises:
        NotImplementedError: If ``array`` is not two-dimensional.
    """
    if array.ndim != 2:
        raise NotImplementedError

    centre = -4 + np.sign(k) * k**2 * step**2
    kernel = np.array([[0, 1, 0],
                       [1, centre, 1],
                       [0, 1, 0]])

    size = array.shape[0]
    mask = mask_f(size, aspect_ratio)
    inner = np.multiply(array, mask)
    mask_dual = mask_f_dual(size, aspect_ratio)
    passthrough = np.multiply(array, mask_dual)

    # jax.lax.conv expects leading batch and feature axes on both operands.
    lhs = inner[np.newaxis, np.newaxis, Ellipsis]
    rhs = kernel[np.newaxis, np.newaxis, Ellipsis] / step**2
    convolved = jax.lax.conv(lhs,
                             rhs,
                             window_strides=(1, ) * array.ndim,
                             padding='SAME')

    stencil_out = np.multiply(np.squeeze(convolved, axis=(0, 1)), mask)
    return stencil_out + passthrough
def l1_unit_projection(x):
    """Euclidean projection to L1 unit ball i.e. argmin_{|v|_1<= 1} |x-v|_2.

    Each column is projected independently; columns already inside the ball
    are returned unchanged.

    Args:
      x: An array of size dim x num.

    Returns:
      An array of size dim x num, the projection to the unit L1 ball.
    """
    # https://dl.acm.org/citation.cfm?id=1390191
    original_shape = x.shape
    if len(original_shape) == 1:
        x = x.reshape(-1, 1)
    column_shape = x.shape

    magnitudes = jnp.abs(x.reshape((-1, column_shape[-1])))
    descending = jnp.sort(magnitudes, axis=0)[::-1, :]
    ranks = (1 + jnp.arange(column_shape[0])).reshape((-1, 1))

    # Candidate soft-thresholds for every possible support size.
    candidates = (jnp.cumsum(descending, axis=0) - 1) / ranks
    # Index of the largest support size whose threshold is still exceeded.
    rho = jnp.max(((descending - candidates) > 0) * ranks - 1,
                  axis=0, keepdims=True)
    theta = jnp.take_along_axis(candidates, rho, axis=0)

    shrunk = jnp.maximum(magnitudes - theta, 0)
    outside_ball = jnp.linalg.norm(
        magnitudes, ord=1, axis=0, keepdims=True) > 1
    projected = jnp.where(outside_ball, shrunk, magnitudes)

    signed = projected.reshape(column_shape) * jnp.sign(x)
    return signed.reshape(original_shape)
def _arcsin(x, do_backprop):
    """Return ``arcsin`` of ``x`` after limiting it to the interval [-1, 1].

    Args:
        x: Input array.
        do_backprop: When true, out-of-range entries are replaced through
            ``where``/``sign`` instead of ``clip``; see
            https://github.com/google/jax/issues/654.
    """
    if do_backprop:
        out_of_range = np.abs(x) >= 1
        bounded = np.where(out_of_range, np.sign(x), x)
    else:
        bounded = np.clip(x, -1, 1)
    return np.arcsin(bounded)
def slogdet(sparse):
    """Calculate the log(determinant) of a sparse matrix.

    Based on equation (2.2) of https://arxiv.org/abs/1112.4379

    Parameters
    ----------
    sparse : array
        3D array of shape (ny, nx, ndiag) of block diagonal elements.

    Returns
    -------
    tuple
        Tuple (sign, logdet) such that sign * exp(logdet) is the
        determinant. If the determinant is zero, logdet = -inf.

    """
    sparse = check_sparse(sparse, square=True)
    N, _, P = sparse.shape

    # Contribution of the last diagonal block. np.prod is the supported
    # spelling; the np.product alias is no longer available.
    last_block = sparse[-1, -1]
    sign = np.prod(np.sign(last_block))
    logdet = np.sum(np.log(np.abs(last_block)))

    # The individual blocks can be calculated in any order so there
    # should be a better way to express this using lax.map but I
    # can't get it to work without "concretization" errors.
    for i in range(N - 1):
        s, ld = _block_det(sparse, i, N, P)
        sign *= s
        logdet += ld
    return sign, logdet
def fast_gradient_method(model_fn, x, eps, norm, clip_min=None, clip_max=None,
                         y=None, targeted=False):
    """
    JAX implementation of the Fast Gradient Method.

    :param model_fn: a callable that takes an input tensor and returns the
              model logits.
    :param x: input tensor.
    :param eps: epsilon (input variation parameter); see
              https://arxiv.org/abs/1412.6572. The perturbation has L-inf
              norm ``eps`` for ``norm=np.inf`` and per-example L2 norm
              ``eps`` for ``norm=2``.
    :param norm: Order of the norm (mimics NumPy). Possible values: np.inf
              or 2.
    :param clip_min: (optional) float. Minimum float value for adversarial
              example components.
    :param clip_max: (optional) float. Maximum float value for adversarial
              example components.
    :param y: (optional) Tensor with one-hot true labels. If targeted is true,
              then provide the target one-hot label. Otherwise, only provide
              this parameter if you'd like to use true labels when crafting
              adversarial samples. Otherwise, model predictions are used as
              labels to avoid the "label leaking" effect (explained in this
              paper: https://arxiv.org/abs/1611.01236). Default is None. This
              argument does not have to be a binary one-hot label (e.g.,
              [0, 1, 0, 0]), it can be floating points values that sum up to 1
              (e.g., [0.05, 0.85, 0.05, 0.05]).
    :param targeted: (optional) bool. Is the attack targeted or untargeted?
              Untargeted, the default, will try to make the label incorrect.
              Targeted will instead try to move in the direction of being
              more like y.
    :return: a tensor for the adversarial example
    :raises ValueError: if ``norm`` is neither np.inf nor 2.
    """
    if norm not in [np.inf, 2]:
        raise ValueError("Norm order must be either np.inf or 2.")

    if y is None:
        # Using model predictions as ground truth to avoid label leaking.
        # The number of classes is taken from the logits instead of being
        # hard-coded.
        logits = model_fn(x)
        x_labels = np.argmax(logits, 1)
        y = one_hot(x_labels, logits.shape[1])

    def loss_adv(image, label):
        pred = model_fn(image[None])
        loss = -np.sum(logsoftmax(pred) * label)
        # A targeted attack ascends the negated loss, i.e. moves towards y.
        return -loss if targeted else loss

    grads_fn = vmap(grad(loss_adv), in_axes=(0, 0), out_axes=0)
    grads = grads_fn(x, y)

    if norm == np.inf:
        perturbation = eps * np.sign(grads)
    else:
        # norm == 2 (guaranteed by the validation above): rescale every
        # example's gradient to L2 norm eps.
        axis = tuple(range(1, len(grads.shape)))
        avoid_zero_div = 1e-12
        square = np.maximum(
            avoid_zero_div,
            np.sum(np.square(grads), axis=axis, keepdims=True))
        perturbation = eps * grads / np.sqrt(square)

    adv_x = x + perturbation

    # If clipping is needed, reset all values outside of [clip_min, clip_max]
    if (clip_min is not None) or (clip_max is not None):
        # We don't currently support one-sided clipping
        assert clip_min is not None and clip_max is not None
        adv_x = np.clip(adv_x, clip_min, clip_max)

    return adv_x
def drive_pytree(params: Params) -> Params:
    """Runs DRIVE quantization on a given pytree.

    Every leaf is replaced by its element-wise sign multiplied by a single
    per-leaf scale.
    """
    leaves, tree_def = jax.tree_util.tree_flatten(params)

    def _quantize(leaf):
        # this uses the unbiased scale from section 4.2 in DRIVE's paper
        # (Scale = norm2(R(x))**2 / norm1(R(x)) )
        squared_l2 = jnp.sum(jnp.power(leaf, 2))
        l1 = jnp.sum(jnp.abs(leaf))
        return squared_l2 * jnp.sign(leaf) / l1

    quantized = [_quantize(leaf) for leaf in leaves]
    return jax.tree_util.tree_unflatten(tree_def, quantized)
def qmult(key, b):
    """
    QMULT Pre-multiply by random orthogonal matrix.
       QMULT(A) is Q*A where Q is a random real orthogonal matrix from
       the Haar distribution, of dimension the number of rows in A.
       Special case: if A is a scalar then QMULT(A) is the same as
                     QMULT(EYE(A)).

       Called by RANDSVD.

       Reference:
       G.W. Stewart, The efficient generation of random
       orthogonal matrices with an application to condition estimators,
       SIAM J. Numer. Anal., 17 (1980), 403-409.

    Args:
        key: JAX random key.
        b: Matrix to pre-multiply, or an integer dimension ``n`` (treated as
            the ``n x n`` identity).

    Returns:
        The product ``Q @ b`` as a JAX array.
    """
    try:
        n = b.shape[0]
    except AttributeError:
        # Scalar special case: b is the dimension of an identity matrix.
        n = b
        a = np.eye(n)
    else:
        # JAX arrays are immutable, so converting is enough; no copy needed.
        a = np.asarray(b)

    d = np.zeros(n)

    for k in range(n - 2, -1, -1):
        # Generate random Householder transformation.
        key, subkey = random.split(key)
        x = random.normal(subkey, (n - k, ))
        s = np.linalg.norm(x)
        # Modification to make sign(0) == 1
        sgn = np.sign(x[0]) + float(x[0] == 0)
        s = sgn * s
        d = d.at[k].set(-sgn)
        x = x.at[0].set(x[0] + s)
        beta = s * x[0]

        # Apply the transformation to a
        y = np.dot(x, a[k:n, :])
        a = a.at[k:n, :].set(a[k:n, :] - np.outer(x, (y / beta)))

    # Tidy up signs (all rows but the last, in one indexed update).
    a = a.at[:n - 1, :].set(d[:n - 1, None] * a[:n - 1, :])

    # Now randomly change the sign (Gaussian dist)
    a = a.at[n - 1, :].set(a[n - 1, :] * np.sign(random.normal(key, ())))

    return a
def compute_log_f_alpha(self, posterior_sample, n_i, log_L_i) -> SignedLogParam:
    """Evaluate every registered function and return the results in signed-log form.

    Each callable in ``self.meta['funcs']`` (paired with
    ``self.meta['names']``) is called with the three arguments, flattened,
    and the pieces are concatenated.

    Returns:
        ``SignedLogParam(log(|res|), sign(res))`` for the concatenated result.
    """
    # use meta data to compute
    flattened = [
        func(posterior_sample, n_i, log_L_i).flatten()
        for _, func in zip(self.meta['names'], self.meta['funcs'])
    ]
    stacked = jnp.concatenate(flattened)
    log_magnitude = jnp.log(jnp.abs(stacked))
    return SignedLogParam(log_magnitude, jnp.sign(stacked))
def and_mask(update):
    """Compute the masked gradients for a single parameter tensor.

    Components are kept only where the mean sign across the leading axis
    reaches ``agreement_threshold`` in magnitude; the averaged update is then
    rescaled by the inverse of the kept fraction.
    """
    sign_agreement = jnp.abs(jnp.mean(jnp.sign(update), 0))
    mask = (sign_agreement >= agreement_threshold).astype(jnp.float32)

    mean_update = jnp.mean(update, 0)
    kept_fraction = mask.sum() / mask.size
    # The small constant guards against division by zero when nothing is kept.
    rescale = 1. / (1e-10 + kept_fraction)
    return mask * mean_update * rescale
def get_sign2(f, *xyz, args=()):
    """Evaluate the sign of ``f`` on the mesh grid spanned by ``xyz``.

    Args:
        f: Function of ``len(xyz) + len(args)`` arguments.
        *xyz: One coordinate vector per grid axis; ``bm.JaxArray`` inputs are
            unwrapped through ``.value``.
        args: Extra arguments passed to ``f`` unchanged (not vectorized).

    Returns:
        Array of signs reshaped to ``(len(xyz[0]), len(xyz[1]), ...)``.
    """
    # NOTE(review): every grid argument is flattened to 1-D below, yet
    # in_axes maps argument i along axis i -- confirm this is what bm.vmap
    # expects.
    in_axes = tuple(range(len(xyz))) + tuple([None] * len(args))
    f = bm.jit(bm.vmap(f_without_jaxarray_return(f), in_axes=in_axes))
    xyz = tuple((v.value if isinstance(v, bm.JaxArray) else v) for v in xyz)
    XYZ = jnp.meshgrid(*xyz)
    XYZ = tuple(jnp.moveaxis(v, 1, 0).flatten() for v in XYZ)
    # A concrete tuple: a one-shot generator is not a valid reshape argument.
    shape = tuple(len(v) for v in xyz)
    return jnp.sign(f(*(XYZ + args))).reshape(shape)
def value(self, count: JTensor) -> JTensor:
    """Return the piecewise-constant schedule value at step ``count``.

    ``self.params.values[i]`` is returned while ``count`` is below
    ``self.params.boundaries[i]`` (first match wins); the last value is used
    once ``count`` has reached the last boundary.
    """
    p = self.params
    # Map the step/boundaries to jnp.float32.
    boundaries = [jnp.array(b, dtype=jnp.float32) for b in p.boundaries]
    values = [jnp.array(val, dtype=jnp.float32) for val in p.values]
    count = count.astype(jnp.float32)

    if not boundaries:
        assert len(values) == 1
        return values[0]

    def _is_positive(delta):
        # 1.0 where delta > 0, otherwise 0.0.
        return jnp.maximum(0., jnp.sign(delta))

    result = 0
    for idx, threshold in enumerate(boundaries):
        candidate = _is_positive(threshold - count) * values[idx]
        # Keep an already selected (positive) value; otherwise try this one.
        result = jnp.where(result > 0, result, candidate)

    # Check if step is greater equal to the last value.
    past_last = _is_positive(1 + count - boundaries[-1])
    return jnp.where(result > 0, result, past_last * values[-1])
def scaled_logsumexp(x, log_b, axis=0):
    """Compute ``log(sum(exp(log_b + x)))`` along ``axis``.

    The maximum of ``log_b + x`` is subtracted before exponentiating and
    added back afterwards.

    Args:
        x: Log-domain values.
        log_b: Log of the scaling weights, broadcastable against ``x``.
        axis: Axis to reduce over.
    """
    shifted = log_b + x
    peak = jnp.amax(shifted, axis=axis, keepdims=True)
    total = jnp.sum(jnp.exp(shifted - peak), axis=axis)
    return jnp.log(jnp.abs(total)) + jnp.squeeze(peak, axis=axis)
def _von_mises_centered(key, concentration, shape, dtype):
    """Draw zero-centred von Mises samples by rejection sampling.

    Args:
        key: JAX random key.
        concentration: Concentration parameter (an array; its dtype is used
            for the uniform draws).
        shape: Shape of the returned sample array.
        dtype: Dtype used only to look up the small-concentration cutoff.

    Returns:
        Array of angles ``sign(u) * arccos(w)``. The sampling loop stops after
        100 iterations at the latest.
    """
    # Cutoff from TensorFlow probability
    # (https://github.com/tensorflow/probability/blob/f051e03dd3cc847d31061803c2b31c564562a993/tensorflow_probability/python/distributions/von_mises.py#L567-L570)
    cutoff_by_dtype = {
        jnp.dtype(jnp.float16): 1.8e-1,
        jnp.dtype(jnp.float32): 2e-2,
        jnp.dtype(jnp.float64): 1.2e-4,
    }
    s_cutoff = cutoff_by_dtype.get(dtype)

    r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration**2)
    rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
    # Below the cutoff the exact expression is replaced by 1 / concentration.
    s = jnp.where(
        concentration > s_cutoff,
        (1.0 + rho**2) / (2.0 * rho),
        1.0 / concentration,
    )

    def keep_sampling(state):
        """Continue until all are done or the iteration cap is reached."""
        step, _, finished, _, _ = state
        return jnp.bitwise_and(step < 100,
                               jnp.logical_not(jnp.all(finished)))

    def propose(state):
        step, rng, finished, _, w_prev = state
        u_key, v_key, rng = random.split(rng, 3)

        u = random.uniform(
            key=u_key,
            shape=shape,
            dtype=concentration.dtype,
            minval=-1.0,
            maxval=1.0,
        )
        z = jnp.cos(jnp.pi * u)
        # Only entries that are not yet accepted receive a new proposal.
        w_next = jnp.where(finished, w_prev, (1.0 + s * z) / (s + z))

        y = concentration * (s - w_next)
        v = random.uniform(key=v_key, shape=shape, dtype=concentration.dtype)
        accepted = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)

        return step + 1, rng, accepted | finished, u, w_next

    initial_state = (
        jnp.array(0),
        key,
        jnp.zeros(shape, dtype=bool),
        jnp.zeros(shape),
        jnp.zeros(shape),
    )
    _, _, _, u, w = lax.while_loop(
        cond_fun=keep_sampling,
        body_fun=propose,
        init_val=initial_state,
    )

    return jnp.sign(u) * jnp.arccos(w)
def pgd(adv_loss, x_init, epsilon, num_steps, step_size, input_bounds=(0., 1.)):
    """Run projected signed-gradient descent on ``adv_loss``.

    Every step moves against the sign of the gradient by ``step_size``, then
    projects back onto the L-inf ball of radius ``epsilon`` around ``x_init``
    and onto ``input_bounds``.

    Args:
        adv_loss: Scalar-valued function of the input to be minimized.
        x_init: Starting point and centre of the epsilon ball; never modified.
        epsilon: Radius of the L-inf ball.
        num_steps: Number of descent steps.
        step_size: Step length per iteration.
        input_bounds: ``(lower, upper)`` bounds of valid inputs.

    Returns:
        The perturbed input after ``num_steps`` iterations.
    """
    grad_adv_loss = jax.grad(adv_loss)
    x = x_init
    for _ in range(num_steps):
        grad_x = grad_adv_loss(x)
        # Out-of-place on purpose: ``x -= ...`` would write into ``x_init``
        # when it is a NumPy array and thereby shift the epsilon ball.
        x = x - jnp.sign(grad_x) * step_size
        x = jnp.clip(x, x_init - epsilon, x_init + epsilon)
        x = jnp.clip(x, input_bounds[0], input_bounds[1])
    return x
def clip(x, value=jnp.inf):
    """Clips elements of x to have magnitude less than or equal to value.

    Args:
        x: Array to clip.
        value: Maximum allowed magnitude; the default (infinity) disables
            clipping and returns ``x`` itself.

    Returns:
        Array in which entries with ``|x| > value`` are replaced by
        ``value * sign(x)``; this includes infinite entries.
    """
    # Guard to short circuit if no value is given.
    if value == jnp.inf:
        return x
    # Selecting with ``where`` instead of blending with a 0/1 mask keeps
    # infinite entries from turning into NaN through ``inf * 0``.
    return jnp.where(jnp.abs(x) <= value, x, value * jnp.sign(x))
def _block_det(sparse, k, N, P):
    """Return sign and log-magnitude of the determinant factor of block ``k``.

    The factor is computed from ``sparse[k, k]`` minus
    ``u @ inv(S) @ v`` (in the file's sparse product), where ``S`` is the
    trailing sub-matrix starting at ``k + 1`` and ``u`` / ``v`` are the row
    and column that couple block ``k`` to it.

    Args:
        sparse: Block array indexed as ``[row, col, diag]``.
        k: Index of the block.
        N: Number of blocks per side.
        P: Number of entries per block.

    Returns:
        Tuple ``(sign, logdet)``.
    """
    u = sparse[k:k + 1, k + 1:N, 0:P]
    S = sparse[k + 1:N, k + 1:N, 0:P]
    v = sparse[k + 1:N, k:k + 1, 0:P]
    Sinv_v = sparse_dot_sparse(inv(S), v)
    M = sparse[k, k] - sparse_dot_sparse(u, Sinv_v)
    # np.prod is the supported spelling; the np.product alias is no longer
    # available.
    sign = np.prod(np.sign(M))
    logdet = np.sum(np.log(np.abs(M)))
    return sign, logdet