Example #1
def sync():
    """Helper method to force this module to synchronize with current distributed context.
    This method should be used when distributed context is manually created or destroyed.
    """
    global _model

    for comp_model_cls in registered_computation_models:
        if comp_model_cls == _SerialModel:
            continue
        model = comp_model_cls.create_from_context()
        if model is not None:
            _model = model
            return

    _model = _SerialModel()
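
The pattern the docstring describes looks roughly like the sketch below: a process group created by hand, then ``sync()`` called so that ``ignite.distributed`` picks it up. This is a minimal sketch, assuming the standard ``env://`` variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) are set by the launcher.

import torch.distributed as dist

import ignite.distributed as idist

dist.init_process_group(backend="gloo")  # context created outside idist
idist.sync()  # force idist to adopt the process group created above
print(idist.get_world_size(), idist.get_rank())

dist.destroy_process_group()
idist.sync()  # context destroyed, so re-sync back to the serial model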
Example #2
def sync(temporary: bool = False) -> None:
    """Helper method to force this module to synchronize with current distributed context.
    This method should be used when distributed context is manually created or destroyed.

    Args:
        temporary: If True, the distributed model is re-synchronized on every call of the
            ``idist.get_*`` methods. This may have a negative performance impact.
    """
    global _model

    for comp_model_cls in registered_computation_models:
        if comp_model_cls == _SerialModel:
            continue
        model = comp_model_cls.create_from_context()
        if model is not None:
            _set_model(model, temporary=temporary)
            return

    _model = _SerialModel()
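
``_set_model`` is referenced above but not shown in these excerpts. A plausible sketch, consistent with the ``_need_to_sync`` flag from Example #4 below; the actual helper in the library may differ.

def _set_model(model, temporary: bool = False) -> None:
    # Hypothetical sketch of the helper referenced above. A "temporary"
    # model leaves _need_to_sync set, so every wrapped idist.get_* call
    # re-synchronizes; a permanent non-serial model clears the flag.
    global _model, _need_to_sync
    _model = model
    _need_to_sync = True
    if not isinstance(model, _SerialModel) and not temporary:
        _need_to_sync = False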
Example #3
def finalize():
    """Finalizes distributed configuration. For example, in case of native pytorch distributed configuration,
    it calls ``dist.destroy_process_group()``.
    """
    _model.finalize()
    _set_model(_SerialModel())
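
In a script, ``finalize()`` pairs with ``initialize()`` as the teardown step. A minimal sketch of that lifecycle, assuming a Gloo backend and launcher-provided ``env://`` variables:

import ignite.distributed as idist

idist.initialize(backend="gloo")
try:
    idist.barrier()  # ... distributed work goes here ...
finally:
    idist.finalize()  # destroys the process group, resets to the serial model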
Example #4
    "initialize",
    "finalize",
    "show_config",
    "set_local_rank",
    "all_reduce",
    "all_gather",
    "barrier",
    "hostname",
    "has_xla_support",
    "has_native_dist_support",
    "sync",
    "registered_computation_models",
    "one_rank_only",
]

_model = _SerialModel()

_need_to_sync = True


def _sync_model_wrapper(func):
    # Lazy synchronization: if the module still holds the serial fallback
    # model, try to pick up a distributed context that may have been created
    # in the meantime, before delegating to the wrapped function.
    @wraps(func)
    def wrapper(*args, **kwargs):
        if isinstance(_model, _SerialModel) and _need_to_sync:
            sync()
        return func(*args, **kwargs)

    return wrapper


def sync():
    ...  # truncated in this excerpt; the full body is shown in Example #1
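
The decorator above is what makes the lazy synchronization transparent to callers. A hedged sketch of how a public getter could use it; ``get_world_size`` is a real method on the computation models, but this particular wrapping is illustrative:

@_sync_model_wrapper
def get_world_size() -> int:
    # Delegates to the current computation model; the decorator has already
    # made sure the model reflects any live distributed context.
    return _model.get_world_size()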