Example #1
    def update(self, x, y_true, params, averager=None):
        # Run forward pass.
        z1, h1, z2, h2 = self.forward(x, params, return_activations=True)

        # Compute errors for each layer (= gradient of cost w.r.t. layer input).
        e2 = h2 - y_true  # gradient through cross entropy loss
        # Backpropagate the error through feedback weights |V2| * sign(W2.T).
        e1 = d_sigmoid(z1) * (e2 @ (np.abs(self.V2) * np.sign(self.W2.T)))

        # Using these errors, compute gradients of cost w.r.t. parameters.
        grad_b1 = e1
        grad_b2 = e2
        grad_W1 = np.outer(x, e1)  # np.outer creates a matrix from two vectors
        grad_W2 = np.outer(h1, e2)

        # Update parameters.
        self.b1 -= params['lr'] * grad_b1
        self.b2 -= params['lr'] * grad_b2
        self.W1 -= params['lr'] * grad_W1
        self.W2 -= params['lr'] * grad_W2

        averager.add(
            'backward_angle',
            np.rad2deg(
                utils.angle_between(
                    (np.abs(self.V2) * np.sign(self.W2.T)).flatten(),
                    self.W2.T.flatten())))

        return h2
Example #2
 def testDerivativeIsMonotonicWrtX(self):
     # Check that the loss increases monotonically with |x|.
     _, _, x, alpha, _, d_x, _, _ = self._precompute_lossfun_inputs()
     # This is just to suppress a warning below.
     d_x = jnp.where(jnp.isfinite(d_x), d_x, jnp.zeros_like(d_x))
     mask = jnp.isfinite(alpha) & (jnp.abs(d_x) >
                                   (300. * jnp.finfo(jnp.float32).eps))
     chex.assert_tree_all_close(jnp.sign(d_x[mask]), jnp.sign(x[mask]))
Example #3
def update_embedding_dbd(embedding, grad, vel, gain, lr, iter_num):
    """Update the embedding using delta-bar-delta."""
    gamma = jnp.where(iter_num > _SWITCH_ITER, _FINAL_MOMENTUM, _INIT_MOMENTUM)
    gain = jnp.where(
        jnp.sign(vel) != jnp.sign(grad), gain + _INCREASE_GAIN,
        jnp.maximum(gain * _DAMP_GAIN, _MIN_GAIN))
    vel = gamma * vel - lr * gain * grad
    embedding += vel
    return embedding, gain, vel
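
A quick way to see the delta-bar-delta rule in action is to iterate it on a toy quadratic. The module constants are not shown in the snippet, so the values below are hypothetical stand-ins and would need to be defined in the same module as the function for the call to resolve them.

import jax.numpy as jnp

# Hypothetical stand-ins for the module-level constants used above.
_INIT_MOMENTUM, _FINAL_MOMENTUM, _SWITCH_ITER = 0.5, 0.8, 250
_INCREASE_GAIN, _DAMP_GAIN, _MIN_GAIN = 0.2, 0.8, 0.01

embedding = jnp.array([1.0, -2.0])
vel = jnp.zeros(2)
gain = jnp.ones(2)
for it in range(10):
    grad = 2.0 * embedding  # gradient of the toy objective ||embedding||^2
    embedding, gain, vel = update_embedding_dbd(embedding, grad, vel, gain,
                                                lr=0.1, iter_num=it)
print(embedding)  # oscillates toward the minimum at the origin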
Example #4
def terngrad_quantize(v: jnp.ndarray, rng: PRNGKey) -> jnp.ndarray:
  """Terngrad algorithm https://arxiv.org/abs/1705.07878.

  Args:
    v: vector to be quantized.
    rng: jax random key.

  Returns:
    Quantized array.
  """
  sigma = jnp.std(v)
  v = jnp.where(jnp.abs(v) > 2.5 * sigma, 2.5 * sigma * jnp.sign(v), v)
  return binary_stochastic_quantize(jnp.abs(v), rng, 0., jnp.amax(
      jnp.abs(v))) * jnp.sign(v)
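
binary_stochastic_quantize is not shown in the snippet. As a rough, self-contained illustration of what the composed operation computes, the sketch below clips to 2.5 sigma and then stochastically ternarizes so the quantizer is unbiased; the helper's exact behavior in the original code base is an assumption.

import jax
import jax.numpy as jnp

def ternary_quantize_sketch(v, rng):
    # Clip to 2.5 standard deviations, as in terngrad_quantize above.
    sigma = jnp.std(v)
    v = jnp.clip(v, -2.5 * sigma, 2.5 * sigma)
    # Keep each entry's sign with probability |v_i| / max|v|, so E[output] = v.
    scale = jnp.maximum(jnp.max(jnp.abs(v)), 1e-12)
    keep = jax.random.bernoulli(rng, jnp.abs(v) / scale)
    return scale * keep * jnp.sign(v)

v = jax.random.normal(jax.random.PRNGKey(0), (8,))
print(ternary_quantize_sketch(v, jax.random.PRNGKey(1)))  # entries lie in {-scale, 0, +scale}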
Example #5
def cost_fn(x, y, power):
    """A transport cost in the form |x-y|^p and its derivative."""
    delta = x[:, :, np.newaxis] - y[:, np.newaxis, :]
    if power == 1.0:
        cost = np.abs(delta)
        derivative = np.sign(delta)
    elif power == 2.0:
        cost = delta**2.0
        derivative = 2.0 * delta
    else:
        abs_diff = np.abs(delta)
        cost = abs_diff**power
        derivative = power * np.sign(delta) * abs_diff**(power - 1.0)
    return cost, derivative
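
A small usage sketch (the inputs below are made up): cost_fn expects batched point sets and returns the pairwise cost together with its derivative with respect to x.

import numpy as np

x = np.array([[0.0, 1.0, 2.0]])        # batch of 3 source points
y = np.array([[0.5, 1.5, 2.5, 3.5]])   # batch of 4 target points

cost, derivative = cost_fn(x, y, power=1.5)
print(cost.shape)  # (1, 3, 4): pairwise |x_i - y_j|**1.5
print(np.allclose(np.sign(derivative), np.sign(x[:, :, None] - y[:, None, :])))  # True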
Example #6
 def get_psd(self, omega):
     omega = np.atleast_1d(omega)
     psd0 = self.term.get_psd(omega)
     arg = 0.5 * self.delta * omega
     arg += 1e-8 * (np.abs(arg) < 1e-8) * np.sign(arg)
     sinc = np.sin(arg) / arg
     return psd0 * sinc**2
Example #7
def logdamp(move: Array) -> Array:

    damped = jnp.where(
        jnp.abs(move) > 1,
        jnp.log(1 + jnp.abs(move) * 1.72) * jnp.sign(move), move)

    return damped
Example #8
 def make_noise_sqrt(rng, shape):
     noise = jax.random.truncated_normal(rng,
                                         lower=-2.,
                                         upper=2.,
                                         shape=shape)
     return jax.lax.stop_gradient(
         jnp.sign(noise) * jnp.sqrt(jnp.abs(noise)))
Example #9
        def update_opt(_, grads, state):
            x, h = state

            grad_vec = jnp.reshape(grads, (-1, 1))

            # Inputs are scaled by a constant factor.
            if isinstance(input_scale, numbers.Number):
                inputs = input_scale * grad_vec

            # Inputs are raw (unmodified) gradients.
            elif input_scale == 'raw':
                inputs = grad_vec

            # Inputs are the log-scale and sign of the gradient.
            elif input_scale == 'log1p':
                scale = jnp.log1p(jnp.abs(grad_vec))
                sign = jnp.sign(grad_vec)
                inputs = jnp.hstack((scale, sign))

            else:
                raise ValueError(f'Invalid input scale {input_scale}.')

            h_next = cell.batch_apply(rnn_params, inputs, h)
            outputs = readout_apply(readout_params, h_next)
            x_next = x + output_scale * jnp.reshape(outputs, x.shape)
            return (x_next, h_next)
Example #10
def lqpos(mps):
    """
    Reshapes the (chiL, d, chiR) MPS tensor into a (chiL, d*chiR) matrix,
    and computes its LQ decomposition, with the phase of L fixed so as to
    have a non-negative main diagonal. A new right-orthogonal
    (chiL, d, chiR) MPS tensor (reshaped from Q) is returned along with
    L.
    In addition to being phase-adjusted, L is normalized by division with
    its L2 norm.

    PARAMETERS
    ----------
    mps (array-like): The (chiL, d, chiR) MPS tensor.

    RETURNS
    -------
    L, mps_R:  A lower-triangular (chiL x chiL) matrix with a non-negative
               main-diagonal, and a right-orthogonal (chiL, d, chiR) MPS
               tensor such that mps = L @ mps_R.
    """
    chiL, d, chiR = mps.shape
    mps_mat = jnp.reshape(mps, (chiL, chiR * d))
    mps_mat = jnp.conj(mps_mat.T)
    Qdag, Ldag = jnp.linalg.qr(mps_mat)
    Q = jnp.conj(Qdag.T)
    L = jnp.conj(Ldag.T)
    phases = jnp.sign(jnp.diag(L))
    L = L * phases
    L = L / jnp.linalg.norm(L)
    Q = jnp.conj(phases)[:, None] * Q
    mps_R = Q.reshape(mps.shape)
    return (L, mps_R)
Example #11
    def init(rng, shape):
        # Check the shape
        std = lax.convert_element_type(stddev, dtype)
        if len(shape) < 2:
            raise ValueError('The array to initialize must be '
                             'at least two-dimensional')
        # Flatten the input shape with the last dimension remaining
        # its original shape so it works for conv2d
        num_rows = 1
        for dim in shape[:-1]:
            num_rows *= dim
        num_cols = shape[-1]
        flat_shape = (num_cols,
                      num_rows) if num_rows < num_cols else (num_rows,
                                                             num_cols)

        # Generate a random matrix
        a = random.normal(rng, flat_shape, dtype=dtype)
        # Compute the qr factorization
        q, r = np.linalg.qr(a)
        # Make Q uniform
        d = np.diag(r)
        q *= np.sign(d)
        if num_rows < num_cols:
            q = np.transpose(q)
        return std * np.reshape(q, shape)
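
As a sanity check of this orthogonal initializer: stddev, dtype and the np/random/lax aliases come from the enclosing factory and its module, so the values assumed below (stddev = 1.0, dtype = jnp.float32) are illustrative only.

import jax.numpy as jnp
from jax import random

W = init(random.PRNGKey(0), (128, 64))  # num_rows >= num_cols, so no transpose is taken
print(W.shape)                                        # (128, 64)
print(jnp.allclose(W.T @ W, jnp.eye(64), atol=1e-4))  # True: orthonormal columns when stddev = 1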
Example #12
def interp(x, xp, fp):
    """
  Simple equivalent of np.interp that compute a linear interpolation.

  We are not doing any checks, so make sure your query points are lying
  inside the array.

  TODO: Implement proper interpolation!

  x, xp, fp need to be 1d arrays
  """
    # First we find the nearest neighbour
    ind = np.argmin((x - xp) ** 2)

    # Perform linear interpolation
    ind = np.clip(ind, 1, len(xp) - 2)

    xi = xp[ind]
    # Figure out if we are on the right or the left of nearest
    s = np.sign(np.clip(x, xp[1], xp[-2]) - xi).astype(np.int64)
    # Cast the +/-1 neighbour offset to an integer so it can be used as an index.
    step = np.copysign(1, s).astype(np.int64)
    a = (fp[ind + step] - fp[ind]) / (xp[ind + step] - xp[ind])
    b = fp[ind] - a * xp[ind]
    return a * x + b
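
A quick comparison against jnp.interp on a made-up grid (assuming np in the snippet refers to jax.numpy):

import jax.numpy as jnp

xp = jnp.linspace(0.0, 1.0, 11)
fp = xp ** 2
print(interp(0.33, xp, fp))      # ~0.111: linear interpolation between the 0.3 and 0.4 knots
print(jnp.interp(0.33, xp, fp))  # same value from the reference implementation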
Example #13
def qrpos(mps):
    """
    Reshapes the (chiL, d, chiR) MPS tensor into a (chiL*d, chiR) matrix,
    and computes its QR decomposition, with the phase of R fixed so as to
    have a non-negative main diagonal. A new left-orthogonal
    (chiL, d, chiR) MPS tensor (reshaped from Q) is returned along with
    R.

    In addition to being phase-adjusted, R is normalized by division with
    its L2 norm.

    PARAMETERS
    ----------
    mps (array-like): The (chiL, d, chiR) MPS tensor.

    RETURNS
    -------
    mps_L, R: A left-orthogonal (chiL, d, chiR) MPS tensor, and an upper
              triangular (chiR x chiR) matrix with a non-negative main
              diagonal such that mps = mps_L @ R.
    """
    chiL, d, chiR = mps.shape
    mps_mat = jnp.reshape(mps, (chiL * d, chiR))
    Q, R = jnp.linalg.qr(mps_mat)
    phases = jnp.sign(jnp.diag(R))
    Q = Q * phases
    R = jnp.conj(phases)[:, None] * R
    R = R / jnp.linalg.norm(R)
    mps_L = Q.reshape(mps.shape)
    return (mps_L, R)
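
A small self-check of the properties stated in the docstring, using a random float32 tensor with hypothetical shapes chiL=4, d=2, chiR=3. Because R is normalized, mps_L @ R reconstructs mps only up to its Frobenius norm.

import jax
import jax.numpy as jnp

mps = jax.random.normal(jax.random.PRNGKey(0), (4, 2, 3))
mps_L, R = qrpos(mps)

Q = mps_L.reshape(4 * 2, 3)
print(jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1e-4))  # True: left-orthogonal
print(jnp.all(jnp.diag(R) >= 0))                     # True: non-negative main diagonal
print(jnp.allclose(Q @ R * jnp.linalg.norm(mps),     # reconstruction up to the norm of mps
                   mps.reshape(4 * 2, 3), atol=1e-4))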
Example #14
 def schedule(count):
   v = init_value
   if boundaries_and_scales is not None:
     for threshold, scale in sorted(boundaries_and_scales.items()):
        indicator = jnp.maximum(0., jnp.sign(threshold - count))
       v = v * indicator + (1 - indicator) * scale * v
   return v
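
schedule closes over init_value and boundaries_and_scales; below is a sketch of the factory it presumably lives in, plus a worked evaluation (the wrapper name and the values are assumptions).

import jax.numpy as jnp

def piecewise_constant(init_value, boundaries_and_scales):
    # Hypothetical reconstruction of the enclosing factory.
    def schedule(count):
        v = init_value
        if boundaries_and_scales is not None:
            for threshold, scale in sorted(boundaries_and_scales.items()):
                indicator = jnp.maximum(0., jnp.sign(threshold - count))
                v = v * indicator + (1 - indicator) * scale * v
        return v
    return schedule

lr = piecewise_constant(1.0, {100: 0.1, 200: 0.5})
print(lr(50), lr(150), lr(250))  # 1.0, 0.1, 0.05: each crossed boundary multiplies by its scale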
Example #15
def helmholtz(array,
              k,
              step=1.0,
              aspect_ratio=1.0,
              mask_f=make_mask,
              mask_f_dual=make_mask_dual):
    """Finite difference approx of the helmholtz operator in 2D."""
    if array.ndim == 2:
        kernel = np.array([[0, 1, 0], [1, -4 + np.sign(k) * k**2 * step**2, 1],
                           [0, 1, 0]])
    else:
        raise NotImplementedError
    mask = mask_f(array.shape[0], aspect_ratio)
    array_masked = np.multiply(array, mask)
    mask_dual = mask_f_dual(array.shape[0], aspect_ratio)
    arr2 = np.multiply(array, mask_dual)
    lhs = array_masked[np.newaxis, np.newaxis, Ellipsis]
    rhs = kernel[np.newaxis, np.newaxis, Ellipsis] / step**2
    result = jax.lax.conv(lhs,
                          rhs,
                          window_strides=(1, ) * array.ndim,
                          padding='SAME')
    squeezed = np.squeeze(result, axis=(0, 1))
    squeezed = np.multiply(squeezed, mask)
    return squeezed + arr2
Example #16
def l1_unit_projection(x):
  """Euclidean projection to L1 unit ball i.e. argmin_{|v|_1<= 1} |x-v|_2.

  Args:
    x: An array of size dim x num.

  Returns:
    An array of size dim x num, the projection to the unit L1 ball.
  """
  # https://dl.acm.org/citation.cfm?id=1390191
  xshape = x.shape
  if len(x.shape) == 1:
    x = x.reshape(-1, 1)
  eshape = x.shape
  v = jnp.abs(x.reshape((-1, eshape[-1])))
  u = jnp.sort(v, axis=0)
  u = u[::-1, :]  # descending
  arange = (1 + jnp.arange(eshape[0])).reshape((-1, 1))
  usum = (jnp.cumsum(u, axis=0) - 1) / arange
  rho = jnp.max(((u - usum) > 0) * arange - 1, axis=0, keepdims=True)
  thx = jnp.take_along_axis(usum, rho, axis=0)
  w = (v - thx).clip(a_min=0)
  w = jnp.where(jnp.linalg.norm(v, ord=1, axis=0, keepdims=True) > 1, w, v)
  x = w.reshape(eshape) * jnp.sign(x)
  return x.reshape(xshape)
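
A worked example with made-up inputs: a column outside the ball is soft-thresholded onto the L1 sphere, while a column already inside is returned unchanged.

import jax.numpy as jnp

x = jnp.array([[0.9], [-2.0], [0.3]])      # dim x num = 3 x 1, with ||x||_1 = 3.2 > 1
w = l1_unit_projection(x)
print(w.ravel())                            # [ 0. -1.  0.]: soft-thresholded at theta = 1
print(jnp.linalg.norm(w, ord=1, axis=0))    # [1.]

inside = jnp.array([[0.2], [-0.3]])         # ||.||_1 = 0.5 <= 1
print(l1_unit_projection(inside).ravel())   # [ 0.2 -0.3]: unchanged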
Example #17
def _arcsin(x, do_backprop):
    if do_backprop:
        # https://github.com/google/jax/issues/654
        x = np.where(np.abs(x) >= 1, np.sign(x), x)
    else:
        x = np.clip(x, -1, 1)
    return np.arcsin(x)
Example #18
def slogdet(sparse):
    """Calculate the log(determinant) of a sparse matrix.

    Based on equation (2.2) of https://arxiv.org/abs/1112.4379

    Parameters
    ----------
    sparse : array
        3D array of shape (ny, nx, ndiag) of block diagonal elements.

    Returns
    -------
    tuple
        Tuple (sign, logdet) such that sign * exp(logdet) is the
        determinant. If the determinant is zero, logdet = -inf.
    """
    sparse = check_sparse(sparse, square=True)
    N, _, P = sparse.shape
    sign = np.product(np.sign(sparse[-1, -1]))
    logdet = np.sum(np.log(np.abs(sparse[-1, -1])))
    # The individual blocks can be calculated in any order so there
    # should be a better way to express this using lax.map but I
    # can't get it to work without "concretization" errors.
    for i in range(N - 1):
        s, ld = _block_det(sparse, i, N, P)
        sign *= s
        logdet += ld
    return sign, logdet
Example #19
def fast_gradient_method(model_fn, x, eps, norm, clip_min=None, clip_max=None,
                         y=None, targeted=False):
  """
  JAX implementation of the Fast Gradient Method.
  :param model_fn: a callable that takes an input tensor and returns the model logits.
  :param x: input tensor.
  :param eps: epsilon (input variation parameter); see https://arxiv.org/abs/1412.6572.
  :param norm: Order of the norm (mimics NumPy). Possible values: np.inf or 2.
  :param clip_min: (optional) float. Minimum float value for adversarial example components.
  :param clip_max: (optional) float. Maximum float value for adversarial example components.
  :param y: (optional) Tensor with one-hot true labels. If targeted is true, then provide the
            target one-hot label. Otherwise, only provide this parameter if you'd like to use true
            labels when crafting adversarial samples. Otherwise, model predictions are used
            as labels to avoid the "label leaking" effect (explained in this paper:
            https://arxiv.org/abs/1611.01236). Default is None. This argument does not have
            to be a binary one-hot label (e.g., [0, 1, 0, 0]); it can be floating point values
            that sum up to 1 (e.g., [0.05, 0.85, 0.05, 0.05]).
  :param targeted: (optional) bool. Is the attack targeted or untargeted?
            Untargeted, the default, will try to make the label incorrect.
            Targeted will instead try to move in the direction of being more like y.
  :return: a tensor for the adversarial example
  """
  if norm not in [np.inf, 2]:
    raise ValueError("Norm order must be either np.inf or 2.")

  if y is None:
    # Using model predictions as ground truth to avoid label leaking
    x_labels = np.argmax(model_fn(x), 1)
    y = one_hot(x_labels, 10)

  def loss_adv(image, label):
    pred = model_fn(image[None])
    loss = - np.sum(logsoftmax(pred) * label)
    if targeted:
      loss = -loss
    return loss

  grads_fn = vmap(grad(loss_adv), in_axes=(0, 0), out_axes=0)
  grads = grads_fn(x, y)

  axis = list(range(1, len(grads.shape)))
  avoid_zero_div = 1e-12
  if norm == np.inf:
    perturbation = eps * np.sign(grads)
  elif norm == 1:
    raise NotImplementedError("L_1 norm has not been implemented yet.")
  elif norm == 2:
    square = np.maximum(avoid_zero_div, np.sum(np.square(grads), axis=axis, keepdims=True))
    perturbation = grads / np.sqrt(square)

  adv_x = x + perturbation

  # If clipping is needed, reset all values outside of [clip_min, clip_max]
  if (clip_min is not None) or (clip_max is not None):
    # We don't currently support one-sided clipping
    assert clip_min is not None and clip_max is not None
    adv_x = np.clip(adv_x, a_min=clip_min, a_max=clip_max)

  return adv_x
Example #20
def drive_pytree(params: Params) -> Params:
  """Runs DRIVE quantization on a given pytree."""
  leaves, tree_def = jax.tree_util.tree_flatten(params)
  new_leaves = []
  for leaf in leaves:
    # Uses the unbiased scale from Section 4.2 of the DRIVE paper:
    # scale = ||R(x)||_2^2 / ||R(x)||_1.
    new_leaves.append(
        jnp.sum(jnp.power(leaf, 2)) * jnp.sign(leaf) / jnp.sum(jnp.abs(leaf)))
  return jax.tree_util.tree_unflatten(tree_def, new_leaves)
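
A quick illustration on a made-up pytree: every leaf collapses to the single magnitude ||leaf||_2^2 / ||leaf||_1 with the original signs. (The random rotation R(x) mentioned in the comment is presumably applied elsewhere in the pipeline.)

import jax.numpy as jnp

params = {'w': jnp.array([0.5, -1.0, 2.0]), 'b': jnp.array([0.1])}
quantized = drive_pytree(params)
print(quantized['w'])  # [ 1.5 -1.5  1.5]: scale = 5.25 / 3.5 = 1.5, signs preserved
print(quantized['b'])  # [0.1]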
Example #21
def qmult(key, b):
    """
    QMULT  Pre-multiply by random orthogonal matrix.
       QMULT(A) is Q*A where Q is a random real orthogonal matrix from
       the Haar distribution, of dimension the number of rows in A.
       Special case: if A is a scalar then QMULT(A) is the same as
                     QMULT(EYE(A)).
       Called by RANDSVD.
       Reference:
       G.W. Stewart, The efficient generation of random
       orthogonal matrices with an application to condition estimators,
       SIAM J. Numer. Anal., 17 (1980), 403-409.
    """
    try:
        n = b.shape[0]
        a = b.copy()
    except AttributeError:
        n = b
        a = np.eye(n)

    d = np.zeros(n)
    for k in range(n - 2, -1, -1):
        # Generate random Householder transformation.
        key, subkey = random.split(key)
        x = random.normal(subkey, (n - k, ))
        s = np.linalg.norm(x)

        # Modification to make sign(0) == 1
        sgn = np.sign(x[0]) + float(x[0] == 0)
        s = sgn * s
        d = index_update(d, k, -sgn)
        x = index_update(x, 0, x[0] + s)
        beta = s * x[0]

        # Apply the transformation to a
        y = np.dot(x, a[k:n, :])
        a = index_update(a, index[k:n, :], a[k:n, :] - np.outer(x, (y / beta)))

    # Tidy up signs.
    for i in range(n - 1):
        a = index_update(a, index[i, :], d[i] * a[i, :])

    # Now randomly change the sign (Gaussian dist)
    a = index_update(a, index[n - 1, :],
                     a[n - 1, :] * np.sign(random.normal(key, ())))
    return a
Example #22
 def compute_log_f_alpha(self, posterior_sample, n_i,
                         log_L_i) -> SignedLogParam:
     # use meta data to compute
     res = []
     for name, func in zip(self.meta['names'], self.meta['funcs']):
         res.append(func(posterior_sample, n_i, log_L_i).flatten())
     res = jnp.concatenate(res)
     return SignedLogParam(jnp.log(jnp.abs(res)), jnp.sign(res))
Example #23
 def and_mask(update):
   # Compute the masked gradients for a single parameter tensor
   mask = jnp.abs(jnp.mean(jnp.sign(update), 0)) >= agreement_threshold
   mask = mask.astype(jnp.float32)
   avg_update = jnp.mean(update, 0)
   mask_t = mask.sum() / mask.size
   update = mask * avg_update * (1. / (1e-10 + mask_t))
   return update
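
and_mask closes over agreement_threshold; assuming it is bound to 1.0 in the enclosing scope (a hypothetical value requiring full sign agreement), a made-up stack of per-environment gradients behaves as follows.

import jax.numpy as jnp

# Rows are environments, columns are parameters.
update = jnp.array([[0.5, -1.0],
                    [0.2,  1.0]])
# First parameter agrees in sign across environments and is kept
# (rescaled by 1 / mask_t); the second is zeroed out.
print(and_mask(update))  # [0.7 0.]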
Example #24
def get_sign2(f, *xyz, args=()):
  in_axes = tuple(range(len(xyz))) + tuple([None] * len(args))
  f = bm.jit(bm.vmap(f_without_jaxarray_return(f), in_axes=in_axes))
  xyz = tuple((v.value if isinstance(v, bm.JaxArray) else v) for v in xyz)
  XYZ = jnp.meshgrid(*xyz)
  XYZ = tuple(jnp.moveaxis(v, 1, 0).flatten() for v in XYZ)
  shape = tuple(len(v) for v in xyz)
  return jnp.sign(f(*(XYZ + args))).reshape(shape)
Example #25
 def value(self, count: JTensor) -> JTensor:
     p = self.params
     # Map the step/boundaries to jnp.float32.
     boundaries = [jnp.array(v, dtype=jnp.float32) for v in p.boundaries]
     values = [jnp.array(v, dtype=jnp.float32) for v in p.values]
     count = count.astype(jnp.float32)
     if not boundaries:
         assert len(values) == 1
         return values[0]
     v = 0
     for i, threshold in enumerate(boundaries):
         indicator = jnp.maximum(0., jnp.sign(threshold - count))
         v = jnp.where(v > 0, v, indicator * values[i])
     # Check if step is greater equal to the last value.
     indicator = jnp.maximum(0., jnp.sign(1 + count - boundaries[-1]))
     v = jnp.where(v > 0, v, indicator * values[-1])
     return v
Example #26
def scaled_logsumexp(x, log_b, axis=0):
    """ logsumexp with scaling
    """
    x_max = jnp.amax(log_b + x, axis=axis, keepdims=True)
    y = jnp.sum(jnp.exp(log_b + x - x_max), axis=axis)
    sign_y = jnp.sign(y)
    abs_y = jnp.log(jnp.abs(y))
    return abs_y + jnp.squeeze(x_max, axis=axis)
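
For positive weights this matches jax.scipy.special.logsumexp with the b argument:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([0.5, 1.0, 2.0])
print(scaled_logsumexp(x, jnp.log(b)))  # log(sum_i b_i * exp(x_i)) ~= 3.89
print(logsumexp(x, b=b))                # same value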
Example #27
def _von_mises_centered(key, concentration, shape, dtype):
    # Cutoff from TensorFlow probability
    # (https://github.com/tensorflow/probability/blob/f051e03dd3cc847d31061803c2b31c564562a993/tensorflow_probability/python/distributions/von_mises.py#L567-L570)
    s_cutoff_map = {
        jnp.dtype(jnp.float16): 1.8e-1,
        jnp.dtype(jnp.float32): 2e-2,
        jnp.dtype(jnp.float64): 1.2e-4,
    }
    s_cutoff = s_cutoff_map.get(dtype)

    r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration**2)
    rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
    s_exact = (1.0 + rho**2) / (2.0 * rho)

    s_approximate = 1.0 / concentration

    s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

    def cond_fn(*args):
        """ check if all are done or reached max number of iterations """
        i, _, done, _, _ = args[0]
        return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

    def body_fn(*args):
        i, key, done, _, w = args[0]
        uni_ukey, uni_vkey, key = random.split(key, 3)

        u = random.uniform(
            key=uni_ukey,
            shape=shape,
            dtype=concentration.dtype,
            minval=-1.0,
            maxval=1.0,
        )
        z = jnp.cos(jnp.pi * u)
        w = jnp.where(done, w,
                      (1.0 + s * z) / (s + z))  # Update where not done

        y = concentration * (s - w)
        v = random.uniform(key=uni_vkey,
                           shape=shape,
                           dtype=concentration.dtype)

        accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)

        return i + 1, key, accept | done, u, w

    init_done = jnp.zeros(shape, dtype=bool)
    init_u = jnp.zeros(shape)
    init_w = jnp.zeros(shape)

    _, _, done, u, w = lax.while_loop(
        cond_fun=cond_fn,
        body_fun=body_fn,
        init_val=(jnp.array(0), key, init_done, init_u, init_w),
    )

    return jnp.sign(u) * jnp.arccos(w)
Example #28
def pgd(adv_loss, x_init, epsilon, num_steps, step_size, input_bounds=(0., 1.)):
  grad_adv_loss = jax.grad(adv_loss)
  x = x_init
  for _ in range(num_steps):
    grad_x = grad_adv_loss(x)
    x -= jnp.sign(grad_x) * step_size
    x = jnp.clip(x, x_init - epsilon, x_init + epsilon)
    x = jnp.clip(x, input_bounds[0], input_bounds[1])
  return x
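
A minimal usage sketch with a made-up adversarial loss: pgd as written descends adv_loss, so the loss should decrease while the iterate stays inside the epsilon-ball.

import jax.numpy as jnp

w = jnp.linspace(-1.0, 1.0, 8)
adv_loss = lambda x: jnp.dot(w, x)     # toy linear adversarial loss

x0 = jnp.full((8,), 0.5)
x_adv = pgd(adv_loss, x0, epsilon=0.1, num_steps=5, step_size=0.05)
print(jnp.max(jnp.abs(x_adv - x0)))    # 0.1: stays within the epsilon-ball
print(adv_loss(x_adv) < adv_loss(x0))  # True: the loss was pushed down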
Example #29
def clip(x, value=jnp.inf):
    """Clips elements of x to have magnitude less than or equal to value."""

    # Guard to short circuit if no value is given.
    if value == jnp.inf:
        return x

    mask = (jnp.abs(x) <= value).astype(jnp.float32)
    return x * mask + value * (1. - mask) * jnp.sign(x)
Example #30
def _block_det(sparse, k, N, P):
    u = sparse[k:k + 1, k + 1:N, 0:P]
    S = sparse[k + 1:N, k + 1:N, 0:P]
    v = sparse[k + 1:N, k:k + 1, 0:P]
    Sinv_v = sparse_dot_sparse(inv(S), v)
    M = sparse[k, k] - sparse_dot_sparse(u, Sinv_v)
    sign = np.product(np.sign(M))
    logdet = np.sum(np.log(np.abs(M)))
    return sign, logdet