Example #1
0
def get_compile_options(
    num_replicas: int,
    num_partitions: int,
    device_assignment=None,
    use_spmd_partitioning: bool = True,
) -> xla_client.CompileOptions:
    """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: Number of replicas for which to compile.
    num_partitions: Number of partitions for which to compile.
    device_assignment: Optional tuple of integers, or ndarray of jax devices,
      indicating the assignment of logical replicas to physical devices
      (default inherited from xla_client.CompileOptions). Must be consistent
      with `num_replicas` and `num_partitions`.
    use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
      partitioning in XLA.

  Returns:
    A populated xla_client.CompileOptions.

  Raises:
    ValueError: if `device_assignment` does not match `num_replicas` or
      `num_partitions`.
  """
    compile_options = xla_client.CompileOptions()
    compile_options.num_replicas = num_replicas
    compile_options.num_partitions = num_partitions
    build_options = compile_options.executable_build_options
    build_options.use_spmd_partitioning = use_spmd_partitioning
    if device_assignment is not None:
        logging.vlog(
            2,
            'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
            num_replicas, num_partitions, device_assignment)
        device_assignment = np.array(device_assignment)

        # Allow 1D device assignment if num_partitions is 1.
        if (device_assignment.ndim == 1) and (num_partitions == 1):
            device_assignment = device_assignment[:, None]

        if num_replicas != device_assignment.shape[0]:
            msg = 'device_assignment does not match num_replicas: {} vs {}.'
            raise ValueError(msg.format(device_assignment, num_replicas))

        if num_partitions != device_assignment.shape[1]:
            msg = 'device_assignment does not match num_partitions: {} vs {}.'
            raise ValueError(msg.format(device_assignment, num_partitions))

        # Also accept device objects: np.array over devices produces an
        # object-dtype array, so extract the integer `.id` of each entry
        # before handing the array to DeviceAssignment.create.
        if device_assignment.dtype == object:
            device_assignment = np.vectorize(lambda d: d.id,
                                             otypes=[int])(device_assignment)
        device_assignment = xla_client.DeviceAssignment.create(
            device_assignment)
        assert device_assignment.replica_count() == num_replicas
        assert device_assignment.computation_count() == num_partitions
        compile_options.device_assignment = device_assignment

    debug_options = compile_options.executable_build_options.debug_options
    # Point XLA's GPU backend at the CUDA installation jax discovered, if any.
    if jax._src.lib.cuda_path is not None:
        debug_options.xla_gpu_cuda_data_dir = jax._src.lib.cuda_path

    # Speed up compilation at the expense of run-time optimization quality.
    if FLAGS.jax_disable_most_optimizations:
        debug_options.xla_backend_optimization_level = 0
        debug_options.xla_llvm_disable_expensive_passes = True
        debug_options.xla_test_all_input_layouts = False

    return compile_options
Example #2
0
def get_compile_options(
        num_replicas: int,
        num_partitions: int,
        device_assignment=None,
        use_spmd_partitioning: bool = True,
        use_auto_spmd_partitioning: bool = False,
        auto_spmd_partitioning_mesh_shape=None,
        auto_spmd_partitioning_mesh_ids=None) -> xla_client.CompileOptions:
    """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: Number of replicas for which to compile.
    num_partitions: Number of partitions for which to compile.
    device_assignment: Optional ndarray of jax devices indicating the assignment
      of logical replicas to physical devices (default inherited from
      xla_client.CompileOptions). Must be consistent with `num_replicas` and
      `num_partitions`.
    use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
      partitioning in XLA.
    use_auto_spmd_partitioning: boolean indicating whether to automatically
      generate XLA shardings for SPMD partitioner.
    auto_spmd_partitioning_mesh_shape: device mesh shape used to create
      auto_spmd_partitioning search space. Defaults to an empty list.
    auto_spmd_partitioning_mesh_ids: device ids used to create
      auto_spmd_partitioning search space. Defaults to an empty list.

  Returns:
    A populated xla_client.CompileOptions.

  Raises:
    ValueError: if `device_assignment` does not match `num_replicas` or
      `num_partitions`.
  """
    # Use None sentinels instead of mutable `[]` defaults (shared across
    # calls); the observable behavior is unchanged.
    if auto_spmd_partitioning_mesh_shape is None:
        auto_spmd_partitioning_mesh_shape = []
    if auto_spmd_partitioning_mesh_ids is None:
        auto_spmd_partitioning_mesh_ids = []
    compile_options = xla_client.CompileOptions()
    compile_options.num_replicas = num_replicas
    compile_options.num_partitions = num_partitions
    build_options = compile_options.executable_build_options
    build_options.use_spmd_partitioning = use_spmd_partitioning
    build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
    if use_auto_spmd_partitioning:
        build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape
        build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids
    if device_assignment is not None:
        logging.vlog(
            2,
            'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
            num_replicas, num_partitions, device_assignment)
        device_assignment = np.array(device_assignment)

        # Allow 1D device assignment if num_partitions is 1.
        if (device_assignment.ndim == 1) and (num_partitions == 1):
            device_assignment = device_assignment[:, None]

        if num_replicas != device_assignment.shape[0]:
            msg = 'device_assignment does not match num_replicas: {} vs {}.'
            raise ValueError(msg.format(device_assignment, num_replicas))

        if num_partitions != device_assignment.shape[1]:
            msg = 'device_assignment does not match num_partitions: {} vs {}.'
            raise ValueError(msg.format(device_assignment, num_partitions))

        # Device objects arrive as an object-dtype array; extract the integer
        # `.id` of each entry before handing it to DeviceAssignment.create.
        if device_assignment.dtype == object:
            device_assignment = np.vectorize(lambda d: d.id,
                                             otypes=[int])(device_assignment)
        device_assignment = xla_client.DeviceAssignment.create(
            device_assignment)
        assert device_assignment.replica_count() == num_replicas
        assert device_assignment.computation_count() == num_partitions
        compile_options.device_assignment = device_assignment

    debug_options = compile_options.executable_build_options.debug_options
    # Point XLA's GPU backend at the CUDA installation jax discovered, if any.
    if jax._src.lib.cuda_path is not None:
        debug_options.xla_gpu_cuda_data_dir = jax._src.lib.cuda_path

    # Speed up compilation at the expense of run-time optimization quality.
    if FLAGS.jax_disable_most_optimizations:
        debug_options.xla_backend_optimization_level = 0
        debug_options.xla_llvm_disable_expensive_passes = True
        debug_options.xla_test_all_input_layouts = False

    # profile_version only exists on sufficiently new xla_extension builds.
    if jax._src.lib.xla_extension_version >= 68:
        compile_options.profile_version = FLAGS.jax_xla_profile_version
    return compile_options