Exemplo n.º 1
0
from jax._src import dtypes
from jax._src import profiler
from jax._src import stages
from jax._src import traceback_util
from jax._src.abstract_arrays import array_types
from jax._src.config import config, flags
from jax._src.lib.mlir import ir
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
import jax._src.util as util

# Alias for the global flag-values object exposed by `jax._src.config.flags`.
FLAGS = flags.FLAGS

# Registers --jax_dump_ir_to. Its default is read from the JAX_DUMP_IR_TO
# environment variable once, at import time; an empty string (the fallback
# when the variable is unset) means "do not dump IR".
# NOTE(review): relies on `os` being imported elsewhere in the file; it is not
# among the imports visible in this excerpt.
flags.DEFINE_string(
    'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
    help="Path to which HLO/MHLO IR that is emitted by JAX as input to the "
         "compiler should be dumped as text files. Optional. If omitted, JAX "
         "will not dump IR.")


# Register this file with traceback_util's exclusion list at import time
# (presumably so its frames are hidden from user-facing JAX tracebacks).
traceback_util.register_exclusion(__file__)

MYPY = False  # Are we currently type checking with mypy? Always False at runtime.

# Short alias for the low-level `xla_client._xla` extension module.
xe = xc._xla

# Convenience type aliases for XLA client objects.
Backend = xe.Client
Device = xc.Device
Buffer = xe.Buffer

XlaExecutable = xc.Executable
Exemplo n.º 2
0
import os
import platform
from typing import Any, List, Sequence, Optional

import iree.compiler
import iree.runtime

from jax._src.config import flags
from jax._src.lib import xla_client
import numpy as np

# Alias for the global flag-values object exposed by `jax._src.config.flags`.
FLAGS = flags.FLAGS


# Registers --jax_iree_backend. The default comes from the JAX_IREE_BACKEND
# environment variable (read once at import time), falling back to 'dylib'.
flags.DEFINE_string(
    'jax_iree_backend', os.getenv('JAX_IREE_BACKEND', 'dylib'),
    'IREE compiler backend to use.')

class IreeDevice:
  """Minimal device descriptor handed out by the IREE client.

  Exposes the identity attributes JAX reads from a device (`id`, `host_id`,
  `process_index`, `platform`, `device_kind`) plus a back-reference to the
  owning client.

  Args:
    client: the client object this device belongs to; stored unchanged.
  """

  def __init__(self, client):
    # Back-reference to the owning client.
    self.client = client
    # Fixed identification strings for this backend.
    self.platform = "iree"
    self.device_kind = "IREE device"
    # All indices are pinned to zero (presumably a single-device,
    # single-process backend -- confirm against the client).
    self.id = self.host_id = self.process_index = 0

  def __str__(self) -> str:
    """Return the constant display name of this device."""
    return "IreeDevice"
Exemplo n.º 3
0
# The IREE backend is optional: bind `iree` to None when its module (or one of
# the modules it imports) is unavailable.
try:
  import jax._src.iree as iree  # type: ignore
except ImportError:  # also covers ModuleNotFoundError, which subclasses it
  iree = None

# Register this file with traceback_util's exclusion list at import time
# (presumably so its frames are hidden from user-facing JAX tracebacks).
traceback_util.register_exclusion(__file__)


# Alias for the XLA client class from the low-level `xla_client._xla` module.
XlaBackend = xla_client._xla.Client

# Alias for the global flag-values object exposed by `jax._src.config.flags`.
FLAGS = flags.FLAGS

# TODO(phawkins): Remove jax_xla_backend.
flags.DEFINE_string(
    'jax_xla_backend', '',
    'Deprecated, please use --jax_platforms instead.')
# Default read from JAX_BACKEND_TARGET at import time and lower-cased.
flags.DEFINE_string(
    'jax_backend_target',
    os.getenv('JAX_BACKEND_TARGET', '').lower(),
    'Either "local" or "rpc:address" to connect to a remote service target.')
# TODO(skye): warn when this is used once we test out --jax_platforms a bit
# Default read from JAX_PLATFORM_NAME at import time and lower-cased.
flags.DEFINE_string(
    'jax_platform_name',
    os.getenv('JAX_PLATFORM_NAME', '').lower(),
    'Deprecated, please use --jax_platforms instead.')
flags.DEFINE_string(
    'jax_platforms',
    os.getenv('JAX_PLATFORMS', '').lower(),
    'Comma-separated list of platform names specifying which platforms jax '
    'should attempt to initialize. The first platform in the list that is '
Exemplo n.º 4
0
from jax import core
from jax._src import dtypes as _dtypes
from jax import lax
from jax._src.config import flags, bool_env, config
from jax._src.util import prod, unzip2
from jax.tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from jax._src.lib import xla_bridge
from jax._src import dispatch
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.experimental.maps import mesh, Mesh

# Alias for the global flag-values object exposed by `jax._src.config.flags`.
FLAGS = flags.FLAGS
# Free-form description of the device under test; empty by default.
flags.DEFINE_string(
    'jax_test_dut',
    '',
    help=
    'Describes the device under test in case special consideration is required.'
)

# Default taken from JAX_NUM_GENERATED_CASES. int() parses it at import time,
# so a non-integer value raises ValueError when this module is imported.
flags.DEFINE_integer('num_generated_cases',
                     int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
                     help='Number of generated cases to test')

# Default taken from JAX_MAX_CASES_SAMPLING_RETRIES; parsed with int() at
# import time, so a non-integer value raises ValueError on import.
flags.DEFINE_integer(
    'max_cases_sampling_retries',
    int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
    'Number of times a failed test sample should be retried. '
    'When an unseen case cannot be generated in this many trials, the '
    'sampling process is terminated.')

flags.DEFINE_bool('jax_skip_slow_tests',
Exemplo n.º 5
0
# Optional IREE backend module; None when it cannot be imported.
iree: Optional[Any]

try:
    import jax._src.iree as iree  # type: ignore
except ImportError:  # also covers ModuleNotFoundError, which subclasses it
    iree = None

# Register this file with traceback_util's exclusion list at import time
# (presumably so its frames are hidden from user-facing JAX tracebacks).
traceback_util.register_exclusion(__file__)

# Alias for the XLA client class from the low-level `xla_client._xla` module.
XlaBackend = xla_client._xla.Client

# Alias for the global flag-values object exposed by `jax._src.config.flags`.
FLAGS = flags.FLAGS

# TODO(phawkins): Remove jax_xla_backend.
flags.DEFINE_string('jax_xla_backend', '',
                    'Deprecated, please use --jax_platforms instead.')
# Default read from JAX_BACKEND_TARGET at import time and lower-cased.
flags.DEFINE_string(
    'jax_backend_target',
    os.getenv('JAX_BACKEND_TARGET', '').lower(),
    'Either "local" or "rpc:address" to connect to a remote service target.')
# TODO(skye): warn when this is used once we test out --jax_platforms a bit
# Default read from JAX_PLATFORM_NAME at import time and lower-cased.
flags.DEFINE_string('jax_platform_name',
                    os.getenv('JAX_PLATFORM_NAME', '').lower(),
                    'Deprecated, please use --jax_platforms instead.')
# Default read via bool_env from JAX_DISABLE_MOST_OPTIMIZATIONS (False if unset).
flags.DEFINE_bool(
    'jax_disable_most_optimizations',
    bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
    'Try not to do much optimization work. This can be useful if the cost of '
    'optimization is greater than that of running a less-optimized program.')
flags.DEFINE_integer(
    'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
Exemplo n.º 6
0
                     int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
                     help='Number of generated cases to test')

# Default taken from JAX_MAX_CASES_SAMPLING_RETRIES; parsed with int() at
# import time, so a non-integer value raises ValueError on import.
flags.DEFINE_integer(
    'max_cases_sampling_retries',
    int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
    'Number of times a failed test sample should be retried. '
    'When an unseen case cannot be generated in this many trials, the '
    'sampling process is terminated.')

# Default read via bool_env from JAX_SKIP_SLOW_TESTS (False if unset).
flags.DEFINE_bool('jax_skip_slow_tests',
                  bool_env('JAX_SKIP_SLOW_TESTS', False),
                  help='Skip tests marked as slow (> 5 sec).')

# Regex filters over test names; both default to empty (no filtering).
flags.DEFINE_string(
    'test_targets', '',
    'Regular expression specifying which tests to run, called via re.search on '
    'the test name. If empty or unspecified, run all tests.')
flags.DEFINE_string(
    'exclude_test_targets', '',
    'Regular expression specifying which tests NOT to run, called via re.search '
    'on the test name. If empty or unspecified, run all tests.')

# NOTE(review): presumably a default numeric tolerance for test comparisons;
# its users are not visible here -- confirm.
EPS = 1e-4


def _dtype(x):
    return (getattr(x, 'dtype', None)
            or np.dtype(_dtypes.python_scalar_dtypes.get(type(x), None))
            or np.asarray(x).dtype)