os.getenv('JAX_PLATFORM_NAME', '').lower(), 'Deprecated, please use --jax_platforms instead.')
# (Above: end of a flag definition that starts before this chunk. Its default
# is read from the JAX_PLATFORM_NAME environment variable, lowercased, and its
# help text marks it as deprecated in favor of --jax_platforms.)

# --jax_platforms: comma-separated, ordered list of backends to initialize.
# The default is the JAX_PLATFORMS environment variable, lowercased, or the
# empty string when it is unset.
flags.DEFINE_string(
    'jax_platforms',
    os.getenv('JAX_PLATFORMS', '').lower(),
    'Comma-separated list of platform names specifying which platforms jax '
    'should attempt to initialize. The first platform in the list that is '
    'successfully initialized will be used as the default platform. For '
    'example, --jax_platforms=cpu,gpu means that CPU and GPU backends will be '
    'initialized, and the CPU backend will be used unless otherwise specified; '
    '--jax_platforms=cpu means that only the CPU backend will be initialized. '
    'By default, jax will try to initialize all available platforms and will '
    'default to GPU or TPU if available, and fallback to CPU otherwise.')

# --jax_disable_most_optimizations: boolean flag. The default comes from the
# JAX_DISABLE_MOST_OPTIMIZATIONS environment variable via `bool_env`, with False
# passed as the fallback. `bool_env` is defined outside this chunk; presumably
# it parses the variable's truthiness. TODO confirm.
flags.DEFINE_bool(
    'jax_disable_most_optimizations',
    bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
    'Try not to do much optimization work. This can be useful if the cost of '
    'optimization is greater than that of running a less-optimized program.')


# NOTE(review): `auto_spmd_partitioning_mesh_shape=[]` and
# `auto_spmd_partitioning_mesh_ids=[]` are mutable default arguments. A single
# list object is shared across all calls. The function body lies outside this
# chunk; confirm it never mutates these lists, or switch to `None` defaults.
def get_compile_options( num_replicas: int, num_partitions: int, device_assignment=None, use_spmd_partitioning: bool = True, use_auto_spmd_partitioning: bool = False, auto_spmd_partitioning_mesh_shape=[], auto_spmd_partitioning_mesh_ids=[]) -> xla_client.CompileOptions: """Returns the compile options to use, as derived from flag values. Args: num_replicas: Number of replicas for which to compile.
'Describes the device under test in case special consideration is required.' )
# (Above: end of a flag definition that starts before this chunk. Its help text
# describes the device under test.)

# --num_generated_cases: the default comes from the JAX_NUM_GENERATED_CASES
# environment variable ('10' when unset). `int()` raises ValueError at import
# time if the variable holds a non-integer string.
flags.DEFINE_integer('num_generated_cases',
  int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
  help='Number of generated cases to test')

# --max_cases_sampling_retries: the default comes from the
# JAX_MAX_CASES_SAMPLING_RETRIES environment variable ('100' when unset). It is
# parsed with `int()` like the flag above.
flags.DEFINE_integer(
    'max_cases_sampling_retries',
    int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
    'Number of times a failed test sample should be retried. '
    'When an unseen case cannot be generated in this many trials, the '
    'sampling process is terminated.')

# --jax_skip_slow_tests: boolean flag. The default comes from the
# JAX_SKIP_SLOW_TESTS environment variable via `bool_env`, with False passed
# as the fallback. `bool_env` is defined outside this chunk.
flags.DEFINE_bool('jax_skip_slow_tests',
                  bool_env('JAX_SKIP_SLOW_TESTS', False),
                  help='Skip tests marked as slow (> 5 sec).')

# --test_targets / --exclude_test_targets: regex filters on test names. Both
# default to the empty string. Per the help text they are applied with
# `re.search`. That matching is done elsewhere, not in this chunk.
flags.DEFINE_string(
    'test_targets', '',
    'Regular expression specifying which tests to run, called via re.search on '
    'the test name. If empty or unspecified, run all tests.')
flags.DEFINE_string(
    'exclude_test_targets', '',
    'Regular expression specifying which tests NOT to run, called via re.search '
    'on the test name. If empty or unspecified, run all tests.')

# Module-level constant. Its use sites are outside this chunk; presumably it is
# a default numeric tolerance. TODO confirm against callers.
EPS = 1e-4


def _dtype(x):