Example #1
0
def get_extension_context(ext_name, **kw):
    """Get the context of the specified extension.

    Every extension module must provide a ``context(**kw)`` function.

    Args:
        ext_name (str) : Module path relative to `nnabla_ext`.
            The deprecated name ``'cuda.cudnn'`` is still accepted; it is
            mapped to ``'cudnn'`` and a warning is logged.
        kw (dict) : Additional keyword arguments forwarded to the
            ``context`` function of the extension module.

    Returns:
        :class:`nnabla.Context`: The context object returned by the
        extension module's ``context(**kw)`` function.

    Example:

        .. code-block:: python

            ctx = get_extension_context('cudnn', device_id='0', type_config='half')
            nn.set_default_context(ctx)

    """
    if ext_name == 'cuda.cudnn':
        from nnabla import logger
        logger.warning(
            'Deprecated extension name "cuda.cudnn" passed. Use "cudnn" instead.'
        )
        ext_name = 'cudnn'
    mod = import_extension_module(ext_name)
    return mod.context(**kw)
Example #2
0
def extension_context(extension_name='cpu', **kw):
    """Get the context of the specified extension.

    Every extension module must provide a ``context(**kw)`` function.

    Args:
        extension_name (str) : Module path relative to `nnabla_ext`.
        kw (dict) : Additional keyword arguments forwarded to the
            ``context`` function of the extension module.

    Returns:
        :class:`nnabla.Context`: The context returned by
        :func:`nnabla.ext_utils.get_extension_context`.

    Note:
        Deprecated. Use :func:`nnabla.ext_utils.get_extension_context` instead.
        Every call logs a deprecation warning.

    Example:

        .. code-block:: python

            ctx = extension_context('cudnn', device_id='0')
            nn.set_default_context(ctx)

    """
    from nnabla import logger
    # Point users at the real module, `nnabla.ext_utils` (plural), which is
    # the one imported just below.
    logger.warning(
        'Deprecated API. Use `nnabla.ext_utils.get_extension_context(ext_name, **kw)`.')
    from nnabla.ext_utils import get_extension_context
    return get_extension_context(extension_name, **kw)
Example #3
0
def lms_scheduler(ctx,
                  use_lms,
                  gpu_memory_size=8 << 30,
                  window_length=12 << 30):
    """Create a large-model-support (LMS) scheduler, or a no-op stand-in.

    LMS is only enabled when ``use_lms`` is true AND at least one entry of
    ``ctx.backend`` names a ``cuda`` or ``cudnn`` backend; otherwise a
    warning is logged (for the missing-backend case) and a no-op scheduler
    is returned.

    Args:
        ctx: Context whose ``backend`` is a sequence of
            ``"<backend>:<type_config>"`` strings.
        use_lms (bool): Whether to use the LMS scheduler.
        gpu_memory_size (int): GPU memory limit in bytes (logged in GB).
        window_length (int): Prefetch window length in bytes (logged in GB).

    Returns:
        A ``SwapInOutScheduler`` built from a CPU context (sharing the type
        config of ``ctx.backend[0]``) and ``ctx`` when LMS is used; otherwise
        a ``DummyScheduler`` whose hooks are ``None`` and whose methods do
        nothing.

    Raises:
        ValueError: When LMS is used and ``ctx.backend[0]`` does not contain
            exactly one ``":"``.
    """
    backend_names = [entry.split(":")[0] for entry in ctx.backend]
    if "cudnn" not in backend_names and "cuda" not in backend_names:
        logger.warning(
            "ctx passed to scheduler doesn't have cuda/cudnn backend. lms scheduler will not be used."
        )
        use_lms = False

    if use_lms:
        logger.info(
            "[LMS] gpu_memory_limit: {}GB, prefetch_window_length: {}GB".
            format(
                float(gpu_memory_size) / (1 << 30),
                float(window_length) / (1 << 30)))

        from nnabla.ext_utils import get_extension_context
        # The host-side context reuses the type config of the first backend
        # so swapped-out arrays keep the same type configuration.
        _, type_config = ctx.backend[0].split(":")
        cpu_ctx = get_extension_context(
            "cpu", device_id="", type_config=type_config)
        return SwapInOutScheduler(cpu_ctx, ctx, gpu_memory_size, window_length)

    class DummyScheduler(object):
        """Stand-in scheduler: no hooks, no-op scheduling, usable in `with`."""
        function_pre_hook = None
        function_post_hook = None
        update_pre_hook = None
        update_post_hook = None

        def start_scheduling(self):
            return None

        def end_scheduling(self):
            return None

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            pass

    return DummyScheduler()
Example #4
0
def undefined_op(func):
    r"""Returns the number of FLOps for undefined operations.

    Fallback counter: it logs a warning naming the function type so the
    omission is visible, and contributes nothing to the total.

    Args:
        func: Function object; only ``func.info.type_name`` is read, for the
            warning message.

    Returns:
        int: Always ``0``.
    """
    logger.warning('FLOps of %s were ignored. Returns 0.', func.info.type_name)
    return 0