Beispiel #1
0
    os.getenv('JAX_PLATFORM_NAME', '').lower(),
    'Deprecated, please use --jax_platforms instead.')
# Ordered list of backends jax should try to bring up; the default value is
# taken from the JAX_PLATFORMS environment variable, lower-cased.
flags.DEFINE_string(
    'jax_platforms',
    os.getenv('JAX_PLATFORMS', '').lower(),
    ('Comma-separated list of platform names specifying which platforms '
     'jax should attempt to initialize. '
     'The first platform in the list that is successfully initialized will '
     'be used as the default platform. '
     'For example, --jax_platforms=cpu,gpu means that CPU and GPU backends '
     'will be initialized, and the CPU backend will be used unless otherwise '
     'specified; --jax_platforms=cpu means that only the CPU backend will be '
     'initialized. '
     'By default, jax will try to initialize all available platforms and '
     'will default to GPU or TPU if available, and fallback to CPU otherwise.'),
)
# Switch to skip most optimization work; the default is read from the
# JAX_DISABLE_MOST_OPTIMIZATIONS environment variable through bool_env, with
# False passed as bool_env's fallback.
flags.DEFINE_bool(
    'jax_disable_most_optimizations',
    bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
    ('Try not to do much optimization work. This can be useful if the cost '
     'of optimization is greater than that of running a less-optimized '
     'program.'),
)

def get_compile_options(
    num_replicas: int,
    num_partitions: int,
    device_assignment=None,
    use_spmd_partitioning: bool = True,
    use_auto_spmd_partitioning: bool = False,
    auto_spmd_partitioning_mesh_shape=[],
    auto_spmd_partitioning_mesh_ids=[]) -> xla_client.CompileOptions:
  """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: Number of replicas for which to compile.
Beispiel #2
0
    'Describes the device under test in case special consideration is required.'
)

# Size of each generated test sample; the default is parsed as an int from the
# JAX_NUM_GENERATED_CASES environment variable (string '10' if unset).
flags.DEFINE_integer(
    'num_generated_cases',
    int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
    help='Number of generated cases to test',
)

# Retry budget for sampling unseen cases; the default is parsed as an int from
# the JAX_MAX_CASES_SAMPLING_RETRIES environment variable (string '100' if unset).
flags.DEFINE_integer(
    'max_cases_sampling_retries',
    int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
    ('Number of times a failed test sample should be retried. When an unseen '
     'case cannot be generated in this many trials, the sampling process is '
     'terminated.'),
)

# Whether tests marked as slow are skipped; the default comes from the
# JAX_SKIP_SLOW_TESTS environment variable through bool_env (fallback False).
flags.DEFINE_bool(
    'jax_skip_slow_tests',
    bool_env('JAX_SKIP_SLOW_TESTS', False),
    help='Skip tests marked as slow (> 5 sec).',
)

# Include filter: a regex matched with re.search against test names
# (empty string by default).
flags.DEFINE_string(
    'test_targets',
    '',
    ('Regular expression specifying which tests to run, called via '
     're.search on the test name. If empty or unspecified, run all tests.'),
)
# Exclude filter: a regex matched with re.search against test names
# (empty string by default).
flags.DEFINE_string(
    'exclude_test_targets',
    '',
    ('Regular expression specifying which tests NOT to run, called via '
     're.search on the test name. If empty or unspecified, run all tests.'),
)

EPS = 0.0001  # Small floating-point constant (1e-4); its consumers are outside this view.


def _dtype(x):