Пример #1
0
    def setUpClass(cls):
        """Prepare the test class to run against JAX.

        Turns on JAX's ``jax_enable_x64`` option and then makes ``"jax"`` the
        default backend of ``dispatch``. When importing or configuring JAX
        fails for any reason, the whole class is skipped instead of failing.

        Raises:
            unittest.SkipTest: if ``jax.config`` cannot be imported or updated.
        """
        try:
            # pylint: disable=import-outside-toplevel
            from jax import config as jax_config

            jax_config.update("jax_enable_x64", True)
        except Exception as exc:
            # Any failure here (missing package or a config error) means the
            # JAX tests cannot run, so skip them rather than error out.
            raise unittest.SkipTest("Skipping jax tests.") from exc

        dispatch.set_default_backend("jax")
Пример #2
0
import contextlib


@contextlib.contextmanager
def disable_x64():
    """Experimental context manager to temporarily disable X64 mode.

    Usage::

      >>> import jax.numpy as jnp
      >>> with disable_x64():
      ...   print(jnp.arange(10.0).dtype)
      ...
      float32

    On exit the previous value of ``jax_enable_x64`` is restored, including
    when the body of the ``with`` statement raises.

    See Also
    --------
    jax.experimental.enable_x64 : temporarily enable X64 mode.
    """
    # Read the flag as a config attribute rather than via the legacy
    # ``config.FLAGS`` namespace; this matches how the flag is read elsewhere.
    previous_x64 = config.jax_enable_x64
    config.update('jax_enable_x64', False)
    try:
        yield
    finally:
        # Always put the caller's setting back, even on error.
        config.update('jax_enable_x64', previous_x64)
Пример #3
0
from jax import config, grad, jacfwd, jacrev, value_and_grad

# Turn on JAX 64-bit mode globally. This runs before the remaining imports
# below are executed (presumably so they see x64 already enabled — confirm).
config.update("jax_enable_x64", True)

import numpy as np
from rdkit import Chem
from scipy.optimize import check_grad, minimize

from timemachine.fe import estimator, free_energy, topology
from timemachine.fe.functional import construct_differentiable_interface, construct_differentiable_interface_fast
from timemachine.ff import Forcefield
from timemachine.lib import LangevinIntegrator, MonteCarloBarostat
from timemachine.md import builders, minimizer
from timemachine.md.barostat.utils import get_bond_list, get_group_indices
from timemachine.parallel.client import CUDAPoolClient
from timemachine.testsystems.relative import hif2a_ligand_pair


def test_absolute_free_energy():
    """Build the inputs for an absolute free-energy test.

    Loads one ligand from ``tests/data/ligands_40.sdf`` (hydrogens kept),
    builds a protein system from ``tests/data/hif2a_nowater_min.pdb`` and
    builds a water (solvent) system.
    """
    # Read every molecule with explicit hydrogens and pick the second one.
    ligands = [m for m in Chem.SDMolSupplier("tests/data/ligands_40.sdf", removeHs=False)]
    mol = ligands[1]

    # Protein complex; outputs other than system, coords and box are discarded.
    complex_system, complex_coords, _, _, complex_box, _ = builders.build_protein_system(
        "tests/data/hif2a_nowater_min.pdb"
    )

    # Solvent-only system; 4.0 is passed straight through (presumably a box
    # size — confirm against builders.build_water_system).
    solvent_system, solvent_coords, solvent_box, _ = builders.build_water_system(4.0)
Пример #4
0
"""

import numpy as np

import scipy
from scipy.spatial.distance import pdist, cdist, squareform
from scipy import linalg
from scipy.optimize import minimize

from functools import partial
import inspect

import jax.numpy as jnp
from jax import jit, jacfwd, grad, value_and_grad, config
from jax.scipy.linalg import solve_triangular as jst
# Turn on JAX 64-bit (double-precision) mode globally at import time.
config.update('jax_enable_x64', True)

#{{{ use '@timeit' to decorate a function for timing
import time


def timeit(f):
    """Decorate *f* so every call is timed over repeated runs.

    Each call of the returned wrapper calls ``f`` 30 times with the same
    arguments, prints the total elapsed time, and returns the result of the
    last call.

    Usage::

        @timeit
        def work(x): ...

    Returns:
        The timing wrapper (the original version fell off the end and
        returned ``None``, turning decorated functions into ``None``).
    """
    import functools

    num = 30  # number of repeated calls per timed invocation
    # functools.partial objects have no __name__; fall back to their repr.
    label = getattr(f, '__name__', f)

    @functools.wraps(f)
    def timed(*args, **kw):
        # perf_counter is a monotonic high-resolution clock meant for intervals.
        ts = time.perf_counter()
        for _ in range(num):  # calls function `num` times
            result = f(*args, **kw)
        te = time.perf_counter()
        print('func: %r took: %2.4f sec for %d runs' %
              (label, te - ts, num))
        return result

    return timed
Пример #5
0
from jax import config, nn, random, tree_util
import jax.numpy as jnp

# jaxns changes the default precision to double precision when it is
# imported, so remember the current setting and undo that action afterwards.
use_x64 = config.jax_enable_x64
try:
    from jaxns.nested_sampling import NestedSampler as OrigNestedSampler
    from jaxns.plotting import plot_cornerplot, plot_diagnostics
    from jaxns.prior_transforms.common import ContinuousPrior
    from jaxns.prior_transforms.prior_chain import PriorChain, UniformBase
    from jaxns.utils import summary
except ImportError as e:
    raise ImportError(
        "To use this module, please install `jaxns` package. It can be"
        " installed with `pip install jaxns`") from e
finally:
    # Restore the setting even if an import fails part-way (e.g. jaxns is
    # installed but one of the submodules above is missing). Otherwise a
    # caller who catches the ImportError would be left in x64 mode.
    config.update("jax_enable_x64", use_x64)

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam, seed, trace
from numpyro.infer import Predictive
from numpyro.infer.reparam import Reparam
from numpyro.infer.util import _guess_max_plate_nesting, _validate_model, log_density

__all__ = ["NestedSampler"]

Пример #6
0
 def tearDown(self):
   """Set JAX's ``jax_numpy_rank_promotion`` option to 'warn', then run base teardown."""
   # NOTE(review): presumably undoes a stricter mode set in setUp — confirm.
   config.update('jax_numpy_rank_promotion', 'warn')
   super().tearDown()
Пример #7
0
import jax.abstract_arrays as j_abstract_arrays
import jax.api_util as j_api_util
import jax.core as j_core
import jax.interpreters.partial_eval as ji_partial_eval
import jax.interpreters.xla as ji_xla
import jax.lax.lax as jl_lax
import jax.linear_util as j_linear_util
import jax.tree_util as j_tree_util
import jax.util as j_util
import numba
import numpy
import traceback

import jax.config as j_config

# Turn on JAX 64-bit mode globally at import time. The inline note says this is
# needed for asserts (presumably dtype-dependent checks in this module — confirm).
j_config.update("jax_enable_x64", True)  # necessary for asserts


def get_function(function, args=None, kwargs=None, catch_numba=True):
    jaxpr, constants = _get_jax_objects(
        function, tuple() if args is None else args, {} if kwargs is None else kwargs
    )

    print()
    print("jaxpr")
    print(repr(jaxpr))
    print("constants")
    print(repr(constants))

    ast_builder = AstBuilder()