Example #1
0
class Module(object):
    """Mocking module class for name dispatching.

    It will register primitives from :mod:`minpy.array_variants`.

    Parameters
    ----------
    old : dict
        Namespace of the original module (presumably its ``globals()`` --
        confirm with callers). It must contain ``'__name__'`` and serves as
        the fallback for names that are neither registered primitives nor
        injected.
    name : None, or str
        Second level name if specified. Primitives are then imported from
        ``minpy.array_variants.<variant>.<name>`` instead of
        ``minpy.array_variants.<variant>``.
    name_injector : dict, or dict-like object, optional
        Mapping used for manual dispatching. It is consulted after the
        primitive registry and before ``old``. Defaults to an empty mapping.
    """

    # Attributes assigned by ``__init__``. ``__getattr__`` reads them, so if
    # one is missing (the instance was created without ``__init__``, e.g. by
    # ``copy``/``pickle``) the lookup must fail fast instead of recursing.
    _INTERNAL_ATTRS = frozenset(
        ['_registry', '_policy', '_logger', '_old', '_name_injector'])

    def __init__(self, old, name=None, name_injector=None):
        # Add module itself into global config
        minpy.Config['modules'].append(self)
        if name_injector is None:
            # Fresh mapping per instance instead of a shared mutable default.
            name_injector = {}
        self._registry = Registry()
        self._policy = minpy.Config['default_policy']
        self._logger = log.get_logger(old['__name__'])
        self._logger.info('Initialize module: {}.'.format(old['__name__']))
        self._old = old
        self._name_injector = name_injector
        if name_injector:
            # A plain dict has no ``__name__``; fall back to its type name so
            # that passing an ordinary dict does not raise here.
            injector_name = getattr(name_injector, '__name__',
                                    type(name_injector).__name__)
            self._logger.info('Inject Name Injector {}'.format(injector_name))
        for vname in variants:
            if name is None:
                modname = 'minpy.array_variants.{}'.format(vname)
            else:
                modname = 'minpy.array_variants.{}.{}'.format(vname, name)
            mod = importlib.import_module(modname)
            self._logger.info('Importing from {}.'.format(modname))
            vtype = variants[vname]
            # Register all primitives of the module.
            before = len(self._registry._reg)
            mod.register_primitives(self._registry,
                                    self._make_primitive_wrapper(vtype))
            self._logger.info('Got {} primitives from {}'.format(
                len(self._registry._reg) - before, modname))
            # Define gradients of primitives.
            mod.def_grads(self._registry, self._make_primitive_getter(vtype))
        self._logger.info('Import {} primitives'.format(
            len(self._registry._reg)))

    @staticmethod
    def _make_primitive_wrapper(vtype):
        """Return ``wrapper(func, *args, **kwargs)`` building ``Primitive``s
        of variant ``vtype``.

        ``vtype`` is bound here once per variant. A lambda defined inside the
        ``__init__`` loop would instead see whichever variant the loop
        visited last if it were invoked after the loop moved on.
        """
        def wrapper(func, *args, **kwargs):
            return Primitive(func, vtype, *args, **kwargs)
        return wrapper

    def _make_primitive_getter(self, vtype):
        """Return ``getter(name)`` fetching registered primitive ``name`` of
        variant ``vtype`` (bound per variant for the same reason as
        ``_make_primitive_wrapper``)."""
        def getter(name):
            return self._registry.get(name, vtype)
        return getter

    def set_policy(self, plc):
        """Set name dispatch policy.

        Parameters
        ----------
        plc : Policy
            New policy.
        """
        assert isinstance(
            plc, Policy), 'Need an instance of `minpy.dispatch.policy.Policy`.'
        self._policy = plc

    @property
    def policy(self):
        """Get policy of current module"""
        return self._policy

    def __getattr__(self, name):
        """Fetch attributes from this module.

        Lookup order: the special names ``__registry__`` and ``__all__``,
        then the primitive registry (yielding a primitive selector for
        further name dispatching), then ``name_injector``, then the original
        module namespace ``old``.

        Parameters
        ----------
        name : str
            Name of attribute.

        Returns
        -------
        object
            A ``PrimitiveSelector`` for registered primitives, otherwise the
            injected or original object.

        Raises
        ------
        AttributeError
            Cannot find attribute.
        """
        if name in self._INTERNAL_ATTRS:
            # Only reached when __init__ did not set it; resolving the lookup
            # below would re-enter __getattr__ without end.
            raise AttributeError('Cannot find name "{}".'.format(name))
        # Special members for internal use.
        if name == '__registry__':
            return self._registry
        elif name == '__all__':
            # ``_old`` is a dict, so ``__all__`` is a key, not an attribute.
            try:
                return self._old['__all__']
            except KeyError:
                raise AttributeError('Cannot find name "{}".'.format(name))
        elif self._registry.has_name(name):
            return PrimitiveSelector(name, self._registry, self._policy)
        elif name in self._name_injector:
            return self._name_injector[name]
        elif name in self._old:
            self._logger.info(
                'No entry found for "{}" in registry, fallback.'.format(name))
            return self._old[name]
        else:
            raise AttributeError('Cannot find name "{}".'.format(name))
Example #2
0
class Module(object):
    """Mocking module class for name dispatching.

    It will register primitives from :mod:`minpy.array_variants`.

    Parameters
    ----------
    old : dict
        Namespace of the original module (presumably its ``globals()`` --
        confirm with callers). It must contain ``'__name__'`` and serves as
        the fallback for names that are not registered primitives.
    name : None, or str
        Second level name if specified. Primitives are then imported from
        ``minpy.array_variants.<variant>.<name>`` instead of
        ``minpy.array_variants.<variant>``.
    """

    # Attributes assigned by ``__init__``. ``__getattr__`` reads them, so if
    # one is missing (the instance was created without ``__init__``, e.g. by
    # ``copy``/``pickle``) the lookup must fail fast instead of recursing.
    _INTERNAL_ATTRS = frozenset(['_registry', '_policy', '_logger', '_old'])

    def __init__(self, old, name=None):
        # Add module itself into global config
        minpy.Config['modules'].append(self)
        self._registry = Registry()
        self._policy = minpy.Config['default_policy']
        self._logger = log.get_logger(old['__name__'])
        self._logger.info('Initialize module: {}.'.format(old['__name__']))
        self._old = old
        for vname in variants:
            if name is None:
                modname = 'minpy.array_variants.{}'.format(vname)
            else:
                modname = 'minpy.array_variants.{}.{}'.format(vname, name)
            mod = importlib.import_module(modname)
            self._logger.info('Importing from {}.'.format(modname))
            vtype = variants[vname]
            # Register all primitives of the module.
            before = len(self._registry._reg)
            mod.register_primitives(self._registry,
                                    self._make_primitive_wrapper(vtype))
            self._logger.info('Got {} primitives from {}'.format(
                len(self._registry._reg) - before, modname))
            # Define gradients of primitives.
            mod.def_grads(self._registry, self._make_primitive_getter(vtype))
        self._logger.info('Import {} primitives'.format(
            len(self._registry._reg)))

    @staticmethod
    def _make_primitive_wrapper(vtype):
        """Return ``wrapper(func, *args, **kwargs)`` building ``Primitive``s
        of variant ``vtype``.

        ``vtype`` is bound here once per variant. A lambda defined inside the
        ``__init__`` loop would instead see whichever variant the loop
        visited last if it were invoked after the loop moved on.
        """
        def wrapper(func, *args, **kwargs):
            return Primitive(func, vtype, *args, **kwargs)
        return wrapper

    def _make_primitive_getter(self, vtype):
        """Return ``getter(name)`` fetching registered primitive ``name`` of
        variant ``vtype`` (bound per variant for the same reason as
        ``_make_primitive_wrapper``)."""
        def getter(name):
            return self._registry.get(name, vtype)
        return getter

    def set_policy(self, plc):
        """Set name dispatch policy.

        Parameters
        ----------
        plc : Policy
            New policy.
        """
        assert isinstance(
            plc, Policy), 'Need an instance of `minpy.dispatch.policy.Policy`.'
        self._policy = plc

    @property
    def policy(self):
        """Get policy of current module"""
        return self._policy

    def __getattr__(self, name):
        """Fetch attributes from this module.

        Lookup order: the special names ``__registry__`` and ``__all__``,
        then the primitive registry (yielding a primitive selector for
        further name dispatching), then the original module namespace
        ``old``.

        Parameters
        ----------
        name : str
            Name of attribute.

        Returns
        -------
        object
            A ``PrimitiveSelector`` for registered primitives, otherwise the
            original object.

        Raises
        ------
        AttributeError
            Cannot find attribute.
        """
        if name in self._INTERNAL_ATTRS:
            # Only reached when __init__ did not set it; resolving the lookup
            # below would re-enter __getattr__ without end.
            raise AttributeError('Cannot find name "{}".'.format(name))
        # Special members for internal use.
        if name == '__registry__':
            return self._registry
        elif name == '__all__':
            # ``_old`` is a dict, so ``__all__`` is a key, not an attribute.
            try:
                return self._old['__all__']
            except KeyError:
                raise AttributeError('Cannot find name "{}".'.format(name))
        elif self._registry.has_name(name):
            return PrimitiveSelector(name, self._registry, self._policy)
        elif name in self._old:
            self._logger.info(
                'No entry found for "{}" in registry, fallback.'.format(name))
            return self._old[name]
        else:
            raise AttributeError('Cannot find name "{}".'.format(name))