from absl.testing import absltest from absl.testing import parameterized from jax.config import config as jax_config from jax import random from jax import jit import jax.numpy as np from jax import test_util as jtu from jax_md import space from jax_md import minimize from jax_md import quantity from jax_md.util import * jax_config.parse_flags_with_absl() jax_config.enable_omnistaging() FLAGS = jax_config.FLAGS PARTICLE_COUNT = 10 OPTIMIZATION_STEPS = 10 STOCHASTIC_SAMPLES = 10 SPATIAL_DIMENSION = [2, 3] if FLAGS.jax_enable_x64: DTYPE = [f32, f64] else: DTYPE = [f32] class DynamicsTest(jtu.JaxTestCase):
import warnings

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np

from jax import api
from jax import dtypes
from jax import numpy as jnp
from jax import ops
from jax import test_util as jtu
from jax import util

from jax.config import config

# Let absl parse JAX's command-line flags, then expose them module-wide.
config.parse_flags_with_absl()
FLAGS = config.FLAGS

# We disable the whitespace continuation check in this file because otherwise it
# makes the test name formatting unwieldy.
# pylint: disable=bad-continuation

# Regex patterns for the "non-tuple sequence used as a multidimensional index"
# message. They differ only in the suggested rewrite: arr[array(seq)] versus
# arr[tuple(seq)]. Presumably matched against warnings/errors later in the file.
ARRAY_MSG = r"Using a non-tuple sequence for multidimensional indexing is not allowed.*arr\[array\(seq\)\]"
TUPLE_MSG = r"Using a non-tuple sequence for multidimensional indexing is not allowed.*arr\[tuple\(seq\)\]"

# Nested dtype groups, each a superset of the previous one:
# floating -> floating + integer -> floating + integer + boolean.
float_dtypes = jtu.dtypes.floating
default_dtypes = float_dtypes + jtu.dtypes.integer
all_dtypes = default_dtypes + jtu.dtypes.boolean
from __future__ import print_function import operator from collections import namedtuple from absl import flags from jax.api import eval_shape from jax.api import jacobian from jax.api import jvp from jax.api import vjp from jax.config import config import jax.numpy as np from jax.tree_util import tree_multimap from jax.tree_util import tree_reduce from neural_tangents.utils import flags as internal_flags from neural_tangents.utils.utils import get_namedtuple config.parse_flags_with_absl() # NOTE(schsam): Is this safe? FLAGS = flags.FLAGS def linearize(f, params): """Returns a function f_lin, the first order taylor approximation to f. Example: >>> # Compute the MSE of the first order Taylor series of a function. >>> f_lin = linearize(f, params) >>> mse = np.mean((f(new_params, x) - f_lin(new_params, x)) ** 2) Args: f: A function that we would like to linearize. It should have the signature f(params, inputs) where params and inputs are `np.ndarray`s and f should