Example #1
0
from absl.testing import absltest
from absl.testing import parameterized

from jax.config import config as jax_config
from jax import random
from jax import jit
import jax.numpy as np

from jax import test_util as jtu

from jax_md import space
from jax_md import minimize
from jax_md import quantity
from jax_md.util import *

# Hand command-line flags to JAX via absl and switch on omnistaging, then keep
# a module-level handle on the flag registry.
jax_config.parse_flags_with_absl()
jax_config.enable_omnistaging()
FLAGS = jax_config.FLAGS

# Test-size constants.
PARTICLE_COUNT = 10
OPTIMIZATION_STEPS = 10
STOCHASTIC_SAMPLES = 10
SPATIAL_DIMENSION = [2, 3]

# f64 is included in the dtype sweep only when JAX's x64 mode is enabled.
DTYPE = [f32, f64] if FLAGS.jax_enable_x64 else [f32]


class DynamicsTest(jtu.JaxTestCase):
Example #2
0
import warnings

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np

from jax import api
from jax import dtypes
from jax import numpy as jnp
from jax import ops
from jax import test_util as jtu
from jax import util

from jax.config import config
# Parse JAX's command-line flags through absl and expose the flag registry.
config.parse_flags_with_absl()
FLAGS = config.FLAGS

# The whitespace-continuation lint is turned off for this whole file; leaving
# it on would make the formatting of test names unwieldy.
# pylint: disable=bad-continuation


# Regex patterns for the "non-tuple sequence for multidimensional indexing"
# message. The two share a prefix and differ only in the suggested fix.
ARRAY_MSG = (r"Using a non-tuple sequence for multidimensional indexing is not allowed.*"
             r"arr\[array\(seq\)\]")
TUPLE_MSG = (r"Using a non-tuple sequence for multidimensional indexing is not allowed.*"
             r"arr\[tuple\(seq\)\]")


# Cumulative dtype groups: floating, then floating + integer, then those plus
# boolean.
float_dtypes = jtu.dtypes.floating
default_dtypes = float_dtypes + jtu.dtypes.integer
all_dtypes = default_dtypes + jtu.dtypes.boolean
Example #3
0
from __future__ import print_function
import operator
from collections import namedtuple
from absl import flags
from jax.api import eval_shape
from jax.api import jacobian
from jax.api import jvp
from jax.api import vjp
from jax.config import config
import jax.numpy as np
from jax.tree_util import tree_multimap
from jax.tree_util import tree_reduce
from neural_tangents.utils import flags as internal_flags
from neural_tangents.utils.utils import get_namedtuple

FLAGS = flags.FLAGS

# Parse JAX's flags through absl as a side effect of importing this module.
# NOTE(schsam): Is this safe?
config.parse_flags_with_absl()


def linearize(f, params):
    """Returns a function f_lin, the first order taylor approximation to f.

  Example:
    >>> # Compute the MSE of the first order Taylor series of a function.
    >>> f_lin = linearize(f, params)
    >>> mse = np.mean((f(new_params, x) - f_lin(new_params, x)) ** 2)

  Args:
    f: A function that we would like to linearize. It should have the signature
       f(params, inputs) where params and inputs are `np.ndarray`s and f should