from jax._src.lib import xla_bridge
from jax._src import dispatch
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.experimental.maps import mesh, Mesh

# NOTE(review): `flags`, `os`, `bool_env`, `int_env` and `xla_client` are used
# below but are not imported in the lines visible here. Presumably they are
# brought into scope elsewhere in the file (`flags` looks like absl.flags,
# `os` like the stdlib module) -- confirm.

# Module-level alias for `flags.FLAGS`, the registry the definitions below
# register into.
FLAGS = flags.FLAGS

# Free-form description of the device under test; defaults to ''.
flags.DEFINE_string(
    'jax_test_dut', '',
    help=
    'Describes the device under test in case special consideration is required.'
)

# Default comes from $JAX_NUM_GENERATED_CASES (fallback '10'). The int()
# conversion runs when this module executes, so a non-integer value in the
# environment raises ValueError at import time.
flags.DEFINE_integer('num_generated_cases',
                     int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
                     help='Number of generated cases to test')

# Default comes from $JAX_MAX_CASES_SAMPLING_RETRIES (fallback '100'); the
# same import-time int() caveat as above applies.
flags.DEFINE_integer(
    'max_cases_sampling_retries',
    int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
    'Number of times a failed test sample should be retried. '
    'When an unseen case cannot be generated in this many trials, the '
    'sampling process is terminated.')

# Default comes from $JAX_SKIP_SLOW_TESTS via bool_env (fallback False).
# bool_env is not visible here -- presumably it parses the variable as a
# boolean; confirm.
flags.DEFINE_bool('jax_skip_slow_tests',
                  bool_env('JAX_SKIP_SLOW_TESTS', False),
                  help='Skip tests marked as slow (> 5 sec).')

# NOTE(review): the help text attached to `test_targets` documents the
# `--jax_platforms` flag (platform list / backend initialization), not
# `--test_targets`. This looks like two flag definitions fused together:
# the real `test_targets` help and/or a separate `jax_platforms` definition
# appear to be missing. Confirm against the upstream source; until then,
# `--help` output for this flag is misleading.
flags.DEFINE_string(
    'test_targets', '',
    'Comma-separated list of platform names specifying which platforms jax '
    'should attempt to initialize. The first platform in the list that is '
    'successfully initialized will be used as the default platform. For '
    'example, --jax_platforms=cpu,gpu means that CPU and GPU backends will be '
    'initialized, and the CPU backend will be used unless otherwise specified; '
    '--jax_platforms=cpu means that only the CPU backend will be initialized. '
    'By default, jax will try to initialize all available platforms and will '
    'default to GPU or TPU if available, and fallback to CPU otherwise.')

# Default comes from $JAX_DISABLE_MOST_OPTIMIZATIONS via bool_env
# (fallback False).
flags.DEFINE_bool(
    'jax_disable_most_optimizations',
    bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
    'Try not to do much optimization work. This can be useful if the cost of '
    'optimization is greater than that of running a less-optimized program.')

# Default comes from $JAX_XLA_PROFILE_VERSION via int_env (fallback 0).
# int_env is not visible here -- presumably it parses the variable as an
# integer; confirm.
flags.DEFINE_integer(
    'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
    'Optional profile version for XLA compilation. '
    'This is meaningful only when XLA is configured to '
    'support the remote compilation profile feature.')


# NOTE(review): auto_spmd_partitioning_mesh_shape and
# auto_spmd_partitioning_mesh_ids default to list literals, which are created
# once and shared across calls. That is only safe if the function body (not
# visible in this chunk) never mutates them -- confirm, or switch to
# None-defaults.
def get_compile_options(
    num_replicas: int,
    num_partitions: int,
    device_assignment=None,
    use_spmd_partitioning: bool = True,
    use_auto_spmd_partitioning: bool = False,
    auto_spmd_partitioning_mesh_shape=[],
    auto_spmd_partitioning_mesh_ids=[]) -> xla_client.CompileOptions:
  """Returns the compile options to use, as derived from flag values.

  Args: