def aug_dynamics(augmented_state, t, flat_args):
    """Evaluate the original ODE together with its vjp (adjoint) terms.

    The slicing below implies ``augmented_state`` has length
    ``2 * state_len + 1 + flat_args.shape[0]``. Only the first two segments,
    ``y`` and ``adjoint``, are read here. The trailing ``1 + len(flat_args)``
    entries presumably carry the vjp_t / vjp_args accumulators; TODO confirm
    against the caller.

    ``flat_func`` is not defined in this block. It is presumably a closure
    over the flattened dynamics ``(y, t, args) -> dy/dt``.

    Returns:
      ``dy/dt`` (flattened), followed by the cotangents of ``flat_func`` with
      respect to ``y``, ``t`` and ``flat_args``, all evaluated at ``-adjoint``.
    """
    n_args = flat_args.shape[0]
    # Shapes are Python ints here, so floor division gives the same int that
    # int(np.floor_divide(...)) would.
    state_len = (augmented_state.shape[0] - n_args - 1) // 2
    y = augmented_state[:state_len]
    adjoint = augmented_state[state_len:2 * state_len]

    dy_dt, vjp_fn = jax.vjp(flat_func, y, t, flat_args)
    # vjp_fn returns one cotangent per primal: (vjp_y, vjp_t, vjp_args).
    cotangents = vjp_fn(-adjoint)
    return np.hstack([np.ravel(dy_dt), np.hstack(cotangents)])
def largest_elem(xs):
    """Return ``(k, |x|)`` for a selected entry ``x`` of the correlator array.

    ``k`` is the largest odd integer strictly below ``xs.size``, computed as
    ``2 * floor(size / 2) - 1``. The selected entry is ``xs[size - k - 3]``,
    which is ``xs[-2]`` for even sizes and ``xs[-1]`` for odd sizes, i.e. the
    last entry at an even (0-based) position.

    NOTE(review): no max search is performed. Presumably that entry is the
    largest in magnitude for the intended inputs; confirm. An empty ``xs``
    cannot be indexed (numpy raises IndexError).
    """
    size = xs.size
    half = np.floor_divide(size, 2)
    k = 2 * half - 1
    idx = size - k - 3
    magnitude = np.abs(xs[idx])
    return k, magnitude
def test_integer_div(self):
    """jax2tf conversion of ``jnp.floor_divide`` on integer operands."""
    # Mixed-sign dividends exercise floor (round toward -inf) semantics.
    dividends = jnp.array([-4, -3, -1, 0, 1, 3, 6])
    divisor = np.int32(3)
    self.ConvertAndCompare(jnp.floor_divide, dividends, divisor)

    expected = jnp.floor_divide(dividends, divisor)
    # Also run the converted function inside a TF1 session (#5831).
    with tf.compat.v1.Session() as session:
        converted = jax2tf.convert(jnp.floor_divide)
        tf1_result = session.run(converted(dividends, divisor))
    self.assertAllClose(expected, tf1_result)
def tune_alpha(f, alpha, target=1000, trigger=1e6):
    """Returns a pair of:

      - a suggested alpha,
      - the maximal size N of the inner product matrix that has a non-zero
        alpha.

    Parameters:
      f: function from alpha => list of correlators (must expose ``.size``).
      alpha: starting value for the search.
      target: if the function has to recompute alpha, it aims for this value.
      trigger: recalculation won't trigger until some element of the list
        returned by f busts out beyond trigger.

    Example call:
      tune_alpha(
        lambda a: single_matrix_correlators(2 * n - 1, a, g, t1, t2),
        alpha=1,
        target=1e6,
        trigger=1e12
      )
    """
    xs = f(alpha)
    done, ys = items_lte_trigger(xs, trigger=trigger)
    if done:
        return alpha, np.floor_divide(xs.size + 1, 2)

    k, elem = largest_elem(ys)
    # Rescale alpha so the selected element would land on `target`, assuming
    # it scales like alpha ** k:  new_alpha = (target * alpha**k / elem)**(1/k).
    # The exponent must be a float. `k` is an integer, and np.reciprocal on an
    # integer truncates to 0 for |k| > 1, which collapsed new_alpha to 1.
    # float_power also avoids numpy's "integers to negative integer powers"
    # error when alpha is an int and k == -1.
    new_alpha = np.float_power(
        np.divide(target * np.float_power(alpha, k), elem), 1.0 / k)
    if new_alpha == 0:
        return alpha, np.floor_divide(ys.size + 1, 2)
    return tune_alpha(f, new_alpha, target=target, trigger=trigger)
def pad_for_pool(x, n_downsampling):
    """Zero-pad ``x`` so it can be downsampled by 2 ``n_downsampling`` times.

    The size of axis -2 is rounded up to the next multiple of
    ``2 ** n_downsampling``, and the same padding is applied to axes 1 and 2.

    NOTE(review): only ``jnp.shape(x)[-2]`` is inspected, yet axes 1 and 2 are
    both padded. This assumes a 4-D input whose axes 1 and 2 have equal size
    (e.g. square NHWC images); TODO confirm with callers.

    Args:
      x: 4-D array; axes 1 and 2 are padded, axes 0 and 3 are left alone.
      n_downsampling: number of factor-2 downsampling steps to prepare for.

    Returns:
      ``(inputs_padded, padding)``, where ``padding`` is the ``(before, after)``
      pad applied to each of axes 1 and 2. When the total pad is odd, the
      extra element goes after.
    """
    block = 2 ** n_downsampling
    problematic_dim = jnp.shape(x)[-2]
    # Distance to the next multiple of `block` (0 if already divisible).
    # Plain Python ints keep the pad widths static.
    n_pad = -problematic_dim % block
    # Put any odd remainder at the end. Using n_pad // 2 on both sides dropped
    # one element for odd pads, e.g. size 5 stayed 5 instead of becoming 6.
    before = n_pad // 2
    padding = (before, n_pad - before)
    paddings = [
        (0, 0),
        padding,
        padding,
        (0, 0),
    ]
    inputs_padded = jnp.pad(x, paddings)
    return inputs_padded, padding
def floor_divide(x1, x2):
    """Element-wise floor division accepting ``JaxArray`` or raw operands.

    ``JaxArray`` arguments are unwrapped through ``.value`` before calling
    ``jnp.floor_divide``. The result is always wrapped in a new ``JaxArray``.
    """
    lhs = x1.value if isinstance(x1, JaxArray) else x1
    rhs = x2.value if isinstance(x2, JaxArray) else x2
    return JaxArray(jnp.floor_divide(lhs, rhs))
lambda x, y, name=None: np.equal(x, y))  # tail of an alias opened above this chunk

# NumPy/SciPy stand-ins for `tf.math` functions. Each lambda accepts TF's
# `name` keyword and ignores it, so TF-style call sites keep working.
# `utils.copy_docstring` presumably attaches the TF docstring to the lambda;
# confirm in `utils`.

erf = utils.copy_docstring(
    tf.math.erf,
    lambda x, name=None: scipy_special.erf(x))

erfc = utils.copy_docstring(
    tf.math.erfc,
    lambda x, name=None: scipy_special.erfc(x))

exp = utils.copy_docstring(
    tf.math.exp,
    lambda x, name=None: np.exp(x))

expm1 = utils.copy_docstring(
    tf.math.expm1,
    lambda x, name=None: np.expm1(x))

floor = utils.copy_docstring(
    tf.math.floor,
    lambda x, name=None: np.floor(x))

# np.floor_divide rounds toward negative infinity.
floordiv = utils.copy_docstring(
    tf.math.floordiv,
    lambda x, y, name=None: np.floor_divide(x, y))

greater = utils.copy_docstring(
    tf.math.greater,
    lambda x, y, name=None: np.greater(x, y))

greater_equal = utils.copy_docstring(
    tf.math.greater_equal,
    lambda x, y, name=None: np.greater_equal(x, y))

# Backed by SciPy's regularized incomplete gamma functions:
# gammainc (lower) for igamma and gammaincc (upper) for igammac.
igamma = utils.copy_docstring(
    tf.math.igamma,
    lambda a, x, name=None: scipy_special.gammainc(a, x))

igammac = utils.copy_docstring(
    tf.math.igammac,
    lambda a, x, name=None: scipy_special.gammaincc(a, x))

# The parameter is named `input` (shadowing the builtin), matching TF's name.
imag = utils.copy_docstring(
    tf.math.imag,
    lambda input, name=None: np.imag(input))