예제 #1
0
from jax._src.lib import xla_bridge
from jax._src import dispatch
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.experimental.maps import mesh, Mesh

# Module-level alias for the flag registry exposed by `flags`, followed by the
# command-line flags registered by this module.
# NOTE(review): `flags`, `bool_env` and `os` are imported outside this excerpt;
# the comments below describe only what the calls here pass to them.
FLAGS = flags.FLAGS

# Free-form description of the device under test; defaults to the empty string.
flags.DEFINE_string(
    'jax_test_dut',
    '',
    help=
    'Describes the device under test in case special consideration is required.'
)

# The default is taken from the JAX_NUM_GENERATED_CASES environment variable
# ('10' when unset) and converted with int() while this module is loading, so
# a non-integer value in the environment raises at that point.
flags.DEFINE_integer('num_generated_cases',
                     int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
                     help='Number of generated cases to test')

# Same pattern as above: the default comes from JAX_MAX_CASES_SAMPLING_RETRIES
# ('100' when unset). The help text is passed positionally here, unlike the
# `help=` keyword used by the neighbouring definitions.
flags.DEFINE_integer(
    'max_cases_sampling_retries',
    int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
    'Number of times a failed test sample should be retried. '
    'When an unseen case cannot be generated in this many trials, the '
    'sampling process is terminated.')

# The default is whatever bool_env returns for JAX_SKIP_SLOW_TESTS with a
# fallback argument of False.
# NOTE(review): presumably bool_env parses the environment variable as a
# boolean -- confirm against its definition.
flags.DEFINE_bool('jax_skip_slow_tests',
                  bool_env('JAX_SKIP_SLOW_TESTS', False),
                  help='Skip tests marked as slow (> 5 sec).')

flags.DEFINE_string(
    'test_targets', '',
예제 #2
0
파일: xla_bridge.py 프로젝트: cloudhan/jax
    'Comma-separated list of platform names specifying which platforms jax '
    'should attempt to initialize. The first platform in the list that is '
    'successfully initialized will be used as the default platform. For '
    'example, --jax_platforms=cpu,gpu means that CPU and GPU backends will be '
    'initialized, and the CPU backend will be used unless otherwise specified; '
    '--jax_platforms=cpu means that only the CPU backend will be initialized. '
    'By default, jax will try to initialize all available platforms and will '
    'default to GPU or TPU if available, and fallback to CPU otherwise.')
# Boolean flag requesting reduced optimization work (see the help text). The
# default is whatever bool_env returns for JAX_DISABLE_MOST_OPTIMIZATIONS with
# a fallback argument of False.
# NOTE(review): bool_env is defined outside this excerpt; presumably it parses
# the environment variable as a boolean -- confirm.
flags.DEFINE_bool(
    'jax_disable_most_optimizations',
    bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
    'Try not to do much optimization work. This can be useful if the cost of '
    'optimization is greater than that of running a less-optimized program.')
# Integer flag carrying an optional XLA profile version. The default is
# whatever int_env returns for JAX_XLA_PROFILE_VERSION with a fallback
# argument of 0.
# NOTE(review): int_env is defined outside this excerpt; presumably it parses
# the environment variable as an integer -- confirm.
flags.DEFINE_integer(
    'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
    'Optional profile version for XLA compilation. '
    'This is meaningful only when XLA is configured to '
    'support the remote compilation profile feature.')


def get_compile_options(
        num_replicas: int,
        num_partitions: int,
        device_assignment=None,
        use_spmd_partitioning: bool = True,
        use_auto_spmd_partitioning: bool = False,
        auto_spmd_partitioning_mesh_shape=[],
        auto_spmd_partitioning_mesh_ids=[]) -> xla_client.CompileOptions:
    """Returns the compile options to use, as derived from flag values.

  Args: