Beispiel #1
0
 def aug_dynamics(augmented_state, t, flat_args):
   """Dynamics of the original system augmented with vjp_y, vjp_t and vjp_args.

   ``augmented_state`` starts with the state ``y`` followed by the adjoint,
   both of length ``state_len``. ``state_len`` is recovered from the total
   length, assuming the remaining entries are one slot for ``t`` plus
   ``flat_args.shape[0]`` slots for the parameters (NOTE(review): confirm
   this layout against the caller).

   Returns the flattened ``flat_func(y, t, flat_args)`` concatenated with the
   vector-Jacobian products of ``flat_func`` against ``-adjoint`` for each
   primal input (``y``, ``t``, ``flat_args``).
   """
   n_params = flat_args.shape[0]
   state_len = (augmented_state.shape[0] - n_params - 1) // 2
   y = augmented_state[:state_len]
   adjoint = augmented_state[state_len:2 * state_len]
   # flat_func is taken from the enclosing scope.
   dy_dt, vjpfun = jax.vjp(flat_func, y, t, flat_args)
   # One cotangent per primal input of flat_func: (y, t, flat_args).
   cotangents = vjpfun(-adjoint)
   return np.hstack([np.ravel(dy_dt), np.hstack(cotangents)])
Beispiel #2
0
def largest_elem(xs):
    """Return ``(k, magnitude)`` for a positionally chosen correlator.

    ``k`` is ``2 * (xs.size // 2) - 1``, which is always odd. The magnitude is
    ``abs(xs[xs.size - k - 3])``. That is the last element when ``xs.size``
    is odd and the second-to-last element when it is even.

    The element is picked by position, not by searching for the maximum.
    NOTE(review): callers presumably rely on the largest correlator sitting
    at that position; confirm. An empty ``xs`` raises ``IndexError``.
    """
    size = xs.size
    half = np.floor_divide(size, 2)
    k = 2 * half - 1
    position = size - k - 3
    return k, np.abs(xs[position])
Beispiel #3
0
 def test_integer_div(self):
     """Check jax2tf conversion of ``jnp.floor_divide`` on mixed-sign integers.

     The converted function is first checked with ``ConvertAndCompare`` and
     then executed again inside a ``tf.compat.v1.Session`` (#5831).
     """
     lhs = jnp.array([-4, -3, -1, 0, 1, 3, 6])
     rhs = np.int32(3)
     self.ConvertAndCompare(jnp.floor_divide, lhs, rhs)
     want = jnp.floor_divide(lhs, rhs)
     # Repeat the check under TF 1 session execution (#5831).
     with tf.compat.v1.Session() as session:
         converted = jax2tf.convert(jnp.floor_divide)
         got = session.run(converted(lhs, rhs))
         self.assertAllClose(want, got)
Beispiel #4
0
def tune_alpha(f, alpha, target=1000, trigger=1e6):
    """Recursively rescale ``alpha`` until ``f(alpha)`` passes ``items_lte_trigger``.

    Returns a pair of:

    - a suggested alpha,
    - the maximal size N of the inner product matrix that has a non-zero alpha.

    Parameters:
      f: function from alpha => list of correlators.
      alpha: starting value for the search.
      target: if the function has to recompute alpha, it aims for this value.
      trigger: recalculation won't trigger until some element of the list
               returned by f grows beyond trigger.

    Example call:

      tune_alpha(
        lambda a: single_matrix_correlators(2 * n - 1, a, g, t1, t2),
        alpha=1,
        target=1e6,
        trigger=1e12
      )
    """
    xs = f(alpha)
    done, ys = items_lte_trigger(xs, trigger=trigger)

    if done:
        return alpha, np.floor_divide(xs.size + 1, 2)

    k, elem = largest_elem(ys)
    # Pick new_alpha so that elem * (new_alpha / alpha) ** k == target, i.e.
    # new_alpha = (target * alpha**k / elem) ** (1 / k).
    # Both powers are taken in floating point. np.reciprocal on an integer k
    # truncates to 0 for |k| > 1, which would force new_alpha to 1.0 forever.
    # An integer alpha raised to an integer k can also overflow, or raise
    # for negative k. k is always odd, so 1 / k is well defined.
    new_alpha = np.power(np.divide(target * np.power(float(alpha), k), elem),
                         np.reciprocal(float(k)))

    # The rescaled alpha underflowed; keep the last usable one.
    if new_alpha == 0:
        return alpha, np.floor_divide(ys.size + 1, 2)

    return tune_alpha(f, new_alpha, target=target, trigger=trigger)
Beispiel #5
0
def pad_for_pool(x, n_downsampling):
    """Zero-pad ``x`` so axis -2 becomes a multiple of ``2**n_downsampling``.

    The same ``(before, after)`` pair is applied to axes 1 and 2, and nothing
    is added to axes 0 and 3, so ``x`` must be rank 4. Only the size of
    axis -2 is measured. NOTE(review): this implicitly assumes axes 1 and 2
    have the same size (e.g. square NHWC maps); confirm with callers.

    Args:
        x: rank-4 array to pad.
        n_downsampling: exponent; axis -2 is padded up to a multiple of
            ``2**n_downsampling``.

    Returns:
        ``(padded_x, padding)``, where ``padding`` is the ``(before, after)``
        pair of Python ints applied to axes 1 and 2. When the total padding
        is odd, the extra element goes after.
    """
    factor = 2**n_downsampling
    problematic_dim = jnp.shape(x)[-2]
    # Plain int arithmetic keeps the pad widths static, so this also works
    # under jax.jit. (-d) % m is the distance up to the next multiple of m.
    n_pad = -problematic_dim % factor
    # Split the padding unevenly when n_pad is odd. A symmetric
    # n_pad // 2 split would drop one element and leave the axis not
    # divisible by factor.
    before = n_pad // 2
    padding = (before, n_pad - before)
    paddings = [
        (0, 0),
        padding,
        padding,
        (0, 0),
    ]
    inputs_padded = jnp.pad(x, paddings)
    return inputs_padded, padding
Beispiel #6
0
def floor_divide(x1, x2):
  """Element-wise floor division that accepts and returns ``JaxArray`` wrappers.

  Any ``JaxArray`` argument is unwrapped through its ``.value`` attribute
  before calling ``jnp.floor_divide``. The result is always wrapped in a new
  ``JaxArray``.
  """
  operands = [x.value if isinstance(x, JaxArray) else x for x in (x1, x2)]
  return JaxArray(jnp.floor_divide(*operands))
Beispiel #7
0
                             lambda x, y, name=None: np.equal(x, y))

# NumPy/SciPy stand-ins for `tf.math` ops. Each entry pairs a TF op with a
# lambda, via `utils.copy_docstring`, that computes the result with
# `np` / `scipy_special`. Every lambda accepts a trailing `name=None` keyword
# that it ignores.
# NOTE(review): presumably `utils.copy_docstring` copies the TF op's docstring
# and may compare signatures, so the lambda parameter names (including
# `input` below) probably must mirror the TF op. Verify before renaming.
erf = utils.copy_docstring(tf.math.erf,
                           lambda x, name=None: scipy_special.erf(x))

erfc = utils.copy_docstring(tf.math.erfc,
                            lambda x, name=None: scipy_special.erfc(x))

exp = utils.copy_docstring(tf.math.exp, lambda x, name=None: np.exp(x))

expm1 = utils.copy_docstring(tf.math.expm1, lambda x, name=None: np.expm1(x))

floor = utils.copy_docstring(tf.math.floor, lambda x, name=None: np.floor(x))

# tf.math.floordiv is backed by np.floor_divide.
floordiv = utils.copy_docstring(tf.math.floordiv,
                                lambda x, y, name=None: np.floor_divide(x, y))

greater = utils.copy_docstring(tf.math.greater,
                               lambda x, y, name=None: np.greater(x, y))

greater_equal = utils.copy_docstring(
    tf.math.greater_equal, lambda x, y, name=None: np.greater_equal(x, y))

# igamma / igammac are backed by scipy's gammainc / gammaincc.
igamma = utils.copy_docstring(
    tf.math.igamma, lambda a, x, name=None: scipy_special.gammainc(a, x))

igammac = utils.copy_docstring(
    tf.math.igammac, lambda a, x, name=None: scipy_special.gammaincc(a, x))

# The parameter is named `input` (shadowing the builtin inside the lambda),
# presumably to match tf.math.imag's signature.
imag = utils.copy_docstring(tf.math.imag,
                            lambda input, name=None: np.imag(input))