# __future__ imports must precede every other statement in the module.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from contextlib import contextmanager

import io
import warnings

import cloudpickle

from horovod.common.util import check_extension

# Run the extension check for horovod.torch before importing anything from
# horovod.torch.mpi_ops.  The 'mpi_lib_v2' name is tried first; if that call
# raises, the check is repeated with 'mpi_lib' / '_mpi_lib', and any error
# from this second call propagates to whoever imported the module.
# NOTE(review): presumably 'mpi_lib' / '_mpi_lib' are the legacy extension
# names -- confirm against horovod.common.util.check_extension.
try:
    check_extension('horovod.torch', 'HOROVOD_WITH_PYTORCH',
                    __file__, 'mpi_lib_v2')
except Exception:
    # `except Exception` instead of a bare `except:` so that
    # KeyboardInterrupt and SystemExit are not swallowed and retried.
    check_extension('horovod.torch', 'HOROVOD_WITH_PYTORCH',
                    __file__, 'mpi_lib', '_mpi_lib')

# Prefer collections.abc; fall back to the top-level collections module on
# interpreters where collections.abc cannot be imported.
try:
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable

# Names re-exported from the horovod.torch subpackages.
from horovod.torch.compression import Compression
from horovod.torch.mpi_ops import allreduce, allreduce_async, allreduce_, allreduce_async_
from horovod.torch.mpi_ops import allgather, allgather_async
from horovod.torch.mpi_ops import broadcast, broadcast_async, broadcast_, broadcast_async_
from horovod.torch.mpi_ops import join
from horovod.torch.mpi_ops import poll, synchronize
# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== # pylint: disable=g-short-docstring-punctuation import os import warnings from horovod.common.util import check_extension, gpu_available check_extension('horovod.tensorflow', 'HOROVOD_WITH_TENSORFLOW', __file__, 'mpi_lib') from horovod.tensorflow import elastic from horovod.tensorflow.compression import Compression from horovod.tensorflow.functions import broadcast_object, broadcast_object_fn, broadcast_variables from horovod.tensorflow.mpi_ops import allgather, broadcast, _allreduce, alltoall from horovod.tensorflow.mpi_ops import init, shutdown from horovod.tensorflow.mpi_ops import size, local_size, rank, local_rank, is_homogeneous from horovod.tensorflow.mpi_ops import rank_op, local_rank_op, size_op, local_size_op from horovod.tensorflow.mpi_ops import mpi_threads_supported, mpi_enabled, mpi_built from horovod.tensorflow.mpi_ops import gloo_enabled, gloo_built from horovod.tensorflow.mpi_ops import nccl_built, ddl_built, ccl_built from horovod.tensorflow.mpi_ops import Average, Sum, Adasum from horovod.tensorflow.mpi_ops import handle_average_backwards_compatibility, check_num_rank_power_of_2 from horovod.tensorflow.util import _executing_eagerly, _make_subgraph, _cache from horovod.tensorflow.mpi_ops import join
# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from horovod.common.util import check_extension, split_list check_extension('horovod.mxnet', 'HOROVOD_WITH_MXNET', __file__, 'mpi_lib') from horovod.mxnet.compression import Compression from horovod.mxnet.functions import allgather_object, broadcast_object from horovod.mxnet.mpi_ops import allgather from horovod.mxnet.mpi_ops import allreduce, allreduce_, grouped_allreduce, grouped_allreduce_ from horovod.mxnet.mpi_ops import alltoall from horovod.mxnet.mpi_ops import broadcast, broadcast_ from horovod.mxnet.mpi_ops import init, shutdown from horovod.mxnet.mpi_ops import is_initialized, start_timeline, stop_timeline from horovod.mxnet.mpi_ops import size, local_size, cross_size, rank, local_rank, cross_rank from horovod.mxnet.mpi_ops import mpi_threads_supported, mpi_enabled, mpi_built from horovod.mxnet.mpi_ops import gloo_enabled, gloo_built from horovod.mxnet.mpi_ops import nccl_built, ddl_built, ccl_built, cuda_built, rocm_built import mxnet as mx