Пример #1
0
 def get_engine(self):
     """Create an ``APEXEngine`` from this object's device and optimisation level.

     ``self._device`` and ``self._opt_level`` are read (in that order) and
     forwarded positionally to the ``APEXEngine`` constructor.
     """
     device = self._device
     opt_level = self._opt_level
     return APEXEngine(device, opt_level)
Пример #2
0
 def get_engine(self):
     """Create an ``APEXEngine`` for ``self._device``.

     The optimisation level is not passed directly: it is wrapped into the
     ``apex_kwargs`` mapping under the ``"opt_level"`` key.
     """
     device = self._device
     apex_kwargs = {"opt_level": self._opt_level}
     return APEXEngine(device, apex_kwargs=apex_kwargs)
Пример #3
0
def get_available_engine(
    fp16: bool = False, ddp: bool = False, amp: bool = False, apex: bool = False
) -> "IEngine":
    """Returns available engine based on given arguments.

    The engine family (plain torch, AMP or APEX) is chosen from the precision
    flags, and the concrete class from the visible hardware: a CPU engine when
    CUDA is unavailable, a single-device engine for one GPU, and a
    data-parallel or distributed-data-parallel engine for several GPUs.

    Args:
        fp16 (bool): option to use fp16 for training. When set and neither
            ``amp`` nor ``apex`` is given, AMP is selected if
            ``SETTINGS.amp_required`` is truthy, otherwise APEX if
            ``SETTINGS.apex_required`` is truthy; if neither is, the flag is
            ignored. Default is `False`.
        ddp (bool): option to use DDP for training. Only has an effect when
            more than one CUDA device is available. Default is `False`.
        amp (bool): option to use AMP for training. Default is `False`.
        apex (bool): option to use APEX for training. Default is `False`.

    Returns:
        IEngine which matches the requirements. If CUDA is not available this
        is always ``DeviceEngine("cpu")`` (once the backend checks pass).

    Raises:
        AssertionError: if both ``amp`` and ``apex`` are requested, or if the
            requested backend is not enabled in ``SETTINGS``.
    """
    from functools import partial

    from catalyst.engines.torch import (
        DataParallelEngine,
        DeviceEngine,
        DistributedDataParallelEngine,
    )

    if fp16 and not amp and not apex:
        # Auto-select a mixed-precision backend: AMP is preferred, APEX is the
        # fallback. NOTE: if neither is enabled, fp16 is silently ignored.
        amp = bool(SETTINGS.amp_required)
        apex = bool(SETTINGS.apex_required) and not amp

    # Check the conflict before availability so that requesting both backends
    # reports the real problem instead of a "not available" error for one of them.
    assert not (amp and apex), "Could not use both apex and amp engines"

    if amp:
        assert (
            SETTINGS.amp_required
        ), "catalyst[amp] is not available, to install it, run `pip install catalyst[amp]`."
        from catalyst.engines.amp import (
            AMPEngine,
            DataParallelAMPEngine,
            DistributedDataParallelAMPEngine,
        )

        single_gpu_engine = AMPEngine
        data_parallel_engine = DataParallelAMPEngine
        distributed_engine = DistributedDataParallelAMPEngine
    elif apex:
        assert (
            SETTINGS.apex_required
        ), "catalyst[apex] is not available, to install it, run `pip install catalyst[apex]`."
        from catalyst.engines.apex import (
            APEXEngine,
            DataParallelAPEXEngine,
            DistributedDataParallelAPEXEngine,
        )

        single_gpu_engine = APEXEngine
        data_parallel_engine = DataParallelAPEXEngine
        distributed_engine = DistributedDataParallelAPEXEngine
    else:
        single_gpu_engine = partial(DeviceEngine, "cuda")
        data_parallel_engine = DataParallelEngine
        distributed_engine = DistributedDataParallelEngine

    # The CUDA check deliberately comes after the backend checks/imports above,
    # so an unavailable backend is reported even on CPU-only machines.
    if not IS_CUDA_AVAILABLE:
        return DeviceEngine("cpu")
    if NUM_CUDA_DEVICES > 1:
        # ddp only changes the result when several GPUs are visible.
        return distributed_engine() if ddp else data_parallel_engine()
    return single_gpu_engine()