Example #1
def model_wrapper(cls: T) -> Union[T, Traceable]:
    """
    Wrap the base model (search space). For example,

    .. code-block:: python

        @model_wrapper
        class MyModel(nn.Module):
            ...

    The wrapper serves two purposes:

        1. Capture the init parameters of the Python class so that it can be re-instantiated in another process.
        2. Reset the uid in the namespace so that the automatic label counting in each model reliably starts from zero.

    Currently, NNI might not complain in simple cases where ``@model_wrapper`` is not actually needed,
    but in the future ``@model_wrapper`` might become mandatory for base models.
    """
    _check_wrapped(cls)

    import torch.nn as nn
    assert issubclass(cls, nn.Module)

    wrapper = trace(cls)

    class reset_wrapper(wrapper):
        def __init__(self, *args, **kwargs):
            with ModelNamespace():
                super().__init__(*args, **kwargs)

    _copy_class_wrapper_attributes(wrapper, reset_wrapper)
    reset_wrapper.__wrapped__ = wrapper.__wrapped__
    reset_wrapper._nni_model_wrapper = True
    return reset_wrapper
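A minimal usage sketch of the decorator above, assuming the NNI 2.x retiarii import path; ``MyModel`` and its ``hidden_units`` parameter are illustrative. It checks the ``_nni_model_wrapper`` marker set in the last lines of the wrapper.

import torch.nn as nn
from nni.retiarii import model_wrapper  # import path assumed for NNI 2.x retiarii

@model_wrapper
class MyModel(nn.Module):
    def __init__(self, hidden_units=64):      # init parameters are captured by trace()
        super().__init__()
        self.fc = nn.Linear(hidden_units, 10)

    def forward(self, x):
        return self.fc(x)

model = MyModel(hidden_units=128)
assert getattr(MyModel, '_nni_model_wrapper', False)  # marker set by the decorator above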
Example #2
def model_wrapper(cls: T) -> Union[T, Traceable]:
    """
    Wrap the model if you are using the pure-Python execution engine. For example,

    .. code-block:: python

        @model_wrapper
        class MyModel(nn.Module):
            ...

    The wrapper serves two purposes:

        1. Capture the init parameters of the Python class so that it can be re-instantiated in another process.
        2. Reset the uid in the ``mutation`` namespace so that each model counts from zero.
           This can be useful in unit tests and other multi-model scenarios.
    """
    _check_wrapped(cls)

    import torch.nn as nn
    assert issubclass(cls, nn.Module)

    wrapper = trace(cls)

    class reset_wrapper(wrapper):
        def __init__(self, *args, **kwargs):
            with ModelNamespace():
                super().__init__(*args, **kwargs)

    _copy_class_wrapper_attributes(wrapper, reset_wrapper)
    reset_wrapper.__wrapped__ = wrapper.__wrapped__
    reset_wrapper._nni_model_wrapper = True
    return reset_wrapper
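A hedged sketch of the init-parameter capture this variant relies on: after wrapping, the captured arguments should be retrievable through the ``Traceable`` attributes (``trace_args`` / ``trace_kwargs``) referenced elsewhere in this module. Exact behaviour may vary across NNI versions, and ``MyModel`` is illustrative.

import torch.nn as nn
from nni.retiarii import model_wrapper  # import path assumed for NNI 2.x retiarii

@model_wrapper
class MyModel(nn.Module):
    def __init__(self, hidden_units):
        super().__init__()
        self.fc = nn.Linear(hidden_units, 10)

m = MyModel(hidden_units=128)
# Captured init arguments, usable to re-instantiate the model in another process.
print(m.trace_kwargs)   # expected to contain {'hidden_units': 128}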
Example #3
def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
    """
    Wrapping a module as a basic unit makes it a primitive and stops the engine from digging deeper into it.

    ``basic_unit_tag`` is true by default. If set to false, the module will not be explicitly marked as a basic unit,
    and the graph parser will continue to parse it. Currently, this is used to handle a special case in ``nn.Sequential``.

    .. code-block:: python

        @basic_unit
        class PrimitiveOp(nn.Module):
            ...
    """
    _check_wrapped(cls)

    import torch.nn as nn
    assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'

    cls = trace(cls)
    cls._nni_basic_unit = basic_unit_tag

    # HACK: for TorchScript
    # https://github.com/pytorch/pytorch/pull/45261
    # https://github.com/pytorch/pytorch/issues/54688
    # I'm not sure whether this will cause any issues
    import torch
    cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
    cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
    cls.trace_args = torch.jit.unused(cls.trace_args)
    cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)

    return cls
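A usage sketch for ``basic_unit``, assuming the NNI 2.x retiarii import path; ``PrimitiveOp`` is illustrative. The engine treats the decorated module as a single primitive node instead of parsing its internals.

import torch
import torch.nn as nn
from nni.retiarii import basic_unit  # import path assumed for NNI 2.x retiarii

@basic_unit
class PrimitiveOp(nn.Module):
    """Parsed as one primitive node; the graph engine does not dig into it."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))

op = PrimitiveOp(channels=16)
assert getattr(PrimitiveOp, '_nni_basic_unit', False)  # tag set by the decorator above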
Example #4
def serialize_cls(cls):
    """
    Create a serializable class.
    """
    warnings.warn('nni.retiarii.serialize_cls is deprecated and will be removed in a future release. ' +
                  'Try to use nni.trace instead.', category=DeprecationWarning)
    return trace(cls)
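A migration sketch for the deprecation above, replacing ``serialize_cls`` with ``nni.trace`` as the warning suggests; ``MyCustomOp`` is illustrative.

import nni
import torch.nn as nn

class MyCustomOp(nn.Module):               # illustrative class
    def __init__(self, hidden_units):
        super().__init__()
        self.fc = nn.Linear(hidden_units, hidden_units)

# Deprecated: SerializableOp = serialize_cls(MyCustomOp)
SerializableOp = nni.trace(MyCustomOp)     # replacement recommended by the warning
op = SerializableOp(hidden_units=128)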
Example #5
def basic_unit(cls: T, basic_unit_tag: bool = True) -> T:
    """
    Wrapping a module as a basic unit makes it a primitive and stops the engine from digging deeper into it.

    ``basic_unit_tag`` is true by default. If set to false, the module will not be explicitly marked as a basic unit,
    and the graph parser will continue to parse it. Currently, this is used to handle a special case in ``nn.Sequential``.

    Although ``basic_unit`` calls ``trace`` in its implementation, it is not meant for serialization. Rather, it captures
    the initialization arguments for mutation. Also, the graph execution engine stops digging into the inner
    modules when it reaches a module decorated with ``basic_unit``.

    .. code-block:: python

        @basic_unit
        class PrimitiveOp(nn.Module):
            ...
    """

    # Internal flag. See nni.trace
    nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
    if nni_trace_flag.lower() == 'disable':
        return cls

    if _check_wrapped(cls, 'basic_unit'):
        return cls

    import torch.nn as nn
    assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'  # type: ignore

    cls = trace(cls)
    cls._nni_basic_unit = basic_unit_tag  # type: ignore

    _torchscript_patch(cls)

    return cls
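A sketch of the internal escape hatch checked at the top of this version: setting ``NNI_TRACE_FLAG`` to ``disable`` before the class is defined makes ``basic_unit`` return the class untouched. The flag is internal, so rely on it only for debugging; the import path and ``PassThrough`` are illustrative.

import os
os.environ['NNI_TRACE_FLAG'] = 'disable'   # must be set before the decorator runs

import torch.nn as nn
from nni.retiarii import basic_unit  # import path assumed for NNI 2.x retiarii

@basic_unit
class PassThrough(nn.Module):              # returned as-is because tracing is disabled
    def forward(self, x):
        return x

assert not hasattr(PassThrough, '_nni_basic_unit')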
Example #6
def serialize(cls, *args, **kwargs):
    """
    Create a serializable instance inline, without a decorator. For example,

    .. code-block:: python

        self.op = serialize(MyCustomOp, hidden_units=128)
    """
    warnings.warn('nni.retiarii.serialize is deprecated and will be removed in a future release. ' +
                  'Try to use nni.trace, e.g., nni.trace(torch.optim.Adam)(learning_rate=1e-4) instead.',
                  category=DeprecationWarning)
    return trace(cls)(*args, **kwargs)
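A migration sketch for the inline ``serialize`` call, following the pattern given in the deprecation message; ``MyCustomOp`` and ``Net`` are illustrative.

import nni
import torch.nn as nn

class MyCustomOp(nn.Module):               # illustrative class
    def __init__(self, hidden_units):
        super().__init__()
        self.fc = nn.Linear(hidden_units, hidden_units)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # old: self.op = serialize(MyCustomOp, hidden_units=128)
        self.op = nni.trace(MyCustomOp)(hidden_units=128)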
Example #7
def model_wrapper(cls: T) -> Union[T, Traceable]:
    """
    Wrap the base model (search space). For example,

    .. code-block:: python

        @model_wrapper
        class MyModel(nn.Module):
            ...

    The wrapper serves two purposes:

    1. Capture the init parameters of the Python class so that it can be re-instantiated in another process.
    2. Reset the uid in the namespace so that the automatic label counting in each model reliably starts from zero.

    Currently, NNI might not complain in simple cases where ``@model_wrapper`` is not actually needed,
    but in the future ``@model_wrapper`` might become mandatory for base models.
    """

    # Internal flag. See nni.trace
    nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
    if nni_trace_flag.lower() == 'disable':
        return cls

    if _check_wrapped(cls, 'model_wrapper'):
        return cls

    import torch.nn as nn
    assert issubclass(cls, nn.Module)

    # subclass can still use trace info
    wrapper = trace(cls, inheritable=True)

    class reset_wrapper(wrapper):
        def __init__(self, *args, **kwargs):
            self._model_namespace = ModelNamespace()
            with self._model_namespace:
                super().__init__(*args, **kwargs)

    _copy_class_wrapper_attributes(wrapper, reset_wrapper)
    reset_wrapper.__wrapped__ = getattr(wrapper, '__wrapped__', wrapper)
    reset_wrapper._nni_model_wrapper = True
    reset_wrapper._traced = True

    _torchscript_patch(cls)

    return reset_wrapper
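A hedged sketch of the ``inheritable=True`` behaviour noted in the comment above: a subclass of the wrapped base model can still be instantiated and keeps access to trace info. ``BaseSpace`` and ``DerivedSpace`` are illustrative, and the exact subclassing semantics may differ between NNI versions.

import torch.nn as nn
from nni.retiarii import model_wrapper  # import path assumed for NNI 2.x retiarii

@model_wrapper
class BaseSpace(nn.Module):
    def __init__(self, hidden_units=64):
        super().__init__()
        self.fc = nn.Linear(hidden_units, 10)

class DerivedSpace(BaseSpace):             # subclassing still works thanks to inheritable tracing
    def __init__(self):
        super().__init__(hidden_units=128)

model = DerivedSpace()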