コード例 #1
0
ファイル: mpi_ops.py プロジェクト: zyx1213271098/horovod
# Load all the necessary MXNet C types.
import ctypes
import os

import mxnet as mx
from mxnet.base import c_handle_array, c_str, c_str_array, check_call, string_types

from horovod.common.util import check_installed_version, get_ext_suffix
from horovod.common.basics import HorovodBasics as _HorovodBasics

# Check for a possible "symbol not found" error caused by an MXNet version
# mismatch. Constructing _HorovodBasics is the step expected to fail in that
# case (presumably because it loads the compiled 'mpi_lib' extension that
# sits next to this file -- confirm in horovod.common.basics).
try:
    _basics = _HorovodBasics(__file__, 'mpi_lib')
except Exception as e:
    # Hand the original failure to the version checker together with the
    # installed MXNet version, then propagate it. A bare ``raise`` re-raises
    # the active exception with its original traceback, without adding an
    # extra entry for the re-raise line the way ``raise e`` does.
    check_installed_version('mxnet', mx.__version__, e)
    raise
else:
    # Loading succeeded; still run the version check against the installed
    # MXNet version.
    check_installed_version('mxnet', mx.__version__)

# Re-export the basic Horovod API as module-level names bound to the shared
# _HorovodBasics instance created above.
init = _basics.init
shutdown = _basics.shutdown
is_initialized = _basics.is_initialized
start_timeline = _basics.start_timeline
stop_timeline = _basics.stop_timeline
size = _basics.size
local_size = _basics.local_size
cross_size = _basics.cross_size
rank = _basics.rank
local_rank = _basics.local_rank
コード例 #2
0
    Args:
      name: The name of the .so file to load.

    Raises:
      NotFoundError if were not able to load .so file.
    """
    filename = resource_loader.get_path_to_datafile(name)
    library = load_library.load_op_library(filename)
    return library


# Check for a possible "symbol not found" error caused by a TensorFlow version
# mismatch: loading the compiled 'mpi_lib' op library (with the platform's
# extension suffix from get_ext_suffix()) is the step expected to fail.
try:
    MPI_LIB = _load_library('mpi_lib' + get_ext_suffix())
except Exception as e:
    # Hand the original failure to the version checker together with the
    # installed TensorFlow version, then propagate it. A bare ``raise``
    # re-raises the active exception with its original traceback, without
    # adding an extra entry for the re-raise line the way ``raise e`` does.
    check_installed_version('tensorflow', tf.__version__, e)
    raise
else:
    # Loading succeeded; still run the version check against the installed
    # TensorFlow version.
    check_installed_version('tensorflow', tf.__version__)

# Shared basics object for the 'mpi_lib' extension next to this file.
_basics = _HorovodBasics(__file__, 'mpi_lib')

# Re-export the basic Horovod API as module-level names bound to _basics.
init = _basics.init
shutdown = _basics.shutdown
is_initialized = _basics.is_initialized
start_timeline = _basics.start_timeline
stop_timeline = _basics.stop_timeline
size = _basics.size
local_size = _basics.local_size
rank = _basics.rank
コード例 #3
0
# Load all the necessary PyTorch C types.
import torch

import warnings

from horovod.common.basics import HorovodBasics as _HorovodBasics
from horovod.common.exceptions import HorovodInternalError
from horovod.common.util import check_installed_version, get_average_backwards_compatibility_fun, gpu_available, num_rank_is_power_2

from horovod.torch.compression import Compression

# Check for a possible "symbol not found" error caused by a PyTorch version
# mismatch: importing the compiled mpi_lib_v2 extension module is the step
# expected to fail in that case.
try:
    from horovod.torch import mpi_lib_v2 as mpi_lib
except Exception as e:
    # Hand the original failure to the version checker together with the
    # installed PyTorch version, then propagate it. A bare ``raise``
    # re-raises the active exception with its original traceback, without
    # adding an extra entry for the re-raise line the way ``raise e`` does.
    check_installed_version('pytorch', torch.__version__, e)
    raise
else:
    # Import succeeded; still run the version check against the installed
    # PyTorch version.
    check_installed_version('pytorch', torch.__version__)

# Module-level empty-string constant; its consumers are not visible here.
_NULL = ""

# Shared basics object for the 'mpi_lib_v2' extension next to this file.
_basics = _HorovodBasics(__file__, 'mpi_lib_v2')
# Re-export the basic Horovod API as module-level names bound to _basics.
init = _basics.init
is_initialized = _basics.is_initialized
start_timeline = _basics.start_timeline
stop_timeline = _basics.stop_timeline
size = _basics.size
local_size = _basics.local_size
cross_size = _basics.cross_size