示例#1
0
    def __call__(self, *args, **kwargs):
        """Run this link's forward computation wrapped by link hooks.

        The active hooks are the global registry returned by
        ``chainer._get_link_hooks()``. If this link has local hooks
        (``self._n_local_link_hooks > 0``), they are merged on top of the
        global hooks. Every hook's ``forward_preprocess`` is called before
        the forward computation, and every hook's ``forward_postprocess``
        after it.

        The forward computation is ``super(Link, self).__call__`` when a
        class after ``Link`` in the MRO defines ``__call__`` (e.g. a mixin).
        Otherwise it is ``self.forward``.

        Args:
            *args: Positional arguments passed through to the forward
                computation and recorded in the hook callback arguments.
            **kwargs: Keyword arguments passed through to the forward
                computation and recorded in the hook callback arguments.

        Returns:
            Whatever the forward computation returns.
        """

        # TODO(niboshi): Support link hooks for other forward methods.
        # Start from the global link-hook registry.
        hooks = chainer._get_link_hooks()
        if self._n_local_link_hooks > 0:
            # Merge into a copy so the global registry itself is not
            # modified; a local hook replaces a global hook of the same name.
            hooks = collections.OrderedDict(hooks)
            hooks.update(self.local_link_hooks)
        # Only the hook objects are needed from here on; an empty view is
        # falsy, which lets the checks below skip callback construction.
        hooks = hooks.values()  # avoid six for performance

        # Call forward_preprocess hook
        if hooks:
            # One callback-args object is built and shared by all hooks.
            cb_args = link_hook._ForwardPreprocessCallbackArgs(
                self, 'forward', args, kwargs)
            for hook in hooks:
                hook.forward_preprocess(cb_args)

        # Call the forward function
        # (See #5078) super().__call__ is used when the method is injected by a
        # mixin class. To keep backward compatibility, the injected one is
        # prioritized over forward().
        forward = getattr(super(Link, self), '__call__', None)
        if forward is None:
            forward = self.forward
        out = forward(*args, **kwargs)

        # Call forward_postprocess hook
        if hooks:
            # Post-process callbacks additionally receive the forward output.
            cb_args = link_hook._ForwardPostprocessCallbackArgs(
                self, 'forward', args, kwargs, out)
            for hook in hooks:
                hook.forward_postprocess(cb_args)

        return out
    def __enter__(self) -> 'LinkHook':
        """Register this hook in the global link-hook registry.

        On success the hook is stored under ``self.name``, its
        ``added(None)`` callback is invoked, and the hook itself is returned
        so it can be bound by ``with ... as hook``.

        Raises:
            KeyError: If a hook with the same name is already registered.
        """
        registry = chainer._get_link_hooks()
        if self.name not in registry:
            registry[self.name] = self
            self.added(None)
            return self
        raise KeyError('hook %s already exists' % self.name)
示例#3
0
    def __enter__(self):
        """Register this hook in the global link-hook registry.

        On success the hook is stored under ``self.name``, its
        ``added(None)`` callback is invoked, and the hook itself is returned
        so it can be bound by ``with ... as hook``.

        Raises:
            KeyError: If a hook with the same name is already registered.
        """
        registry = chainer._get_link_hooks()
        if self.name not in registry:
            registry[self.name] = self
            self.added(None)
            return self
        raise KeyError('hook %s already exists' % self.name)
def delete_linkhook(linkhook, prefix='', logger=None):
    """Unregister a hook from the global link-hook registry.

    The registry key is ``prefix + linkhook.name``. If nothing is registered
    under that key, a warning is logged and the function returns. Otherwise
    the *registered* entry's ``deleted(None)`` callback is invoked and the
    entry is removed.

    Args:
        linkhook: Hook whose ``name`` (together with ``prefix``) identifies
            the registry entry to remove.
        prefix (str): String prepended to ``linkhook.name`` to form the key.
        logger: Object with a ``warning`` method, used to report a missing
            entry. Defaults to ``getLogger(__name__)`` when falsy.
    """
    name = prefix + linkhook.name
    link_hooks = chainer._get_link_hooks()
    # Test membership on the mapping itself; ``.keys()`` is redundant.
    if name not in link_hooks:
        logger = logger or getLogger(__name__)
        # Pre-formatted message: a caller-supplied ``logger`` may only accept
        # a single positional argument.
        logger.warning('linkhook {} is not registered'.format(name))
        return
    # Notify the hook before it is removed from the registry.
    link_hooks[name].deleted(None)
    del link_hooks[name]
示例#5
0
def delete_linkhook(linkhook, prefix='', logger=None):
    """Unregister a hook from the global link-hook registry.

    The registry key is ``prefix + linkhook.name``. If nothing is registered
    under that key, a warning is logged and the function returns. Otherwise
    the *registered* entry's ``deleted(None)`` callback is invoked and the
    entry is removed.

    Args:
        linkhook: Hook whose ``name`` (together with ``prefix``) identifies
            the registry entry to remove.
        prefix (str): String prepended to ``linkhook.name`` to form the key.
        logger: Object with a ``warning`` method, used to report a missing
            entry. Defaults to ``getLogger(__name__)`` when falsy.
    """
    name = prefix + linkhook.name
    link_hooks = chainer._get_link_hooks()
    # Test membership on the mapping itself; ``.keys()`` is redundant.
    if name not in link_hooks:
        logger = logger or getLogger(__name__)
        # Pre-formatted message: a caller-supplied ``logger`` may only accept
        # a single positional argument.
        logger.warning('linkhook {} is not registered'.format(name))
        return
    # Notify the hook before it is removed from the registry.
    link_hooks[name].deleted(None)
    del link_hooks[name]
def add_linkhook(linkhook, prefix='', logger=None):
    """Register a hook in the global link-hook registry.

    The registry key is ``prefix + linkhook.name``. An existing entry under
    the same key is overwritten with a warning rather than raising. The
    displaced hook receives its ``deleted(None)`` callback first, so every
    ``added(None)`` stays paired with a ``deleted(None)`` just as in
    ``delete_linkhook``.

    Args:
        linkhook: Hook to register; its ``added(None)`` callback is invoked.
        prefix (str): String prepended to ``linkhook.name`` to form the key.
        logger: Object with a ``warning`` method, used to report an
            overwrite. Defaults to ``getLogger(__name__)`` when falsy.

    Returns:
        The ``linkhook`` argument.
    """
    link_hooks = chainer._get_link_hooks()
    name = prefix + linkhook.name
    if name in link_hooks:
        logger = logger or getLogger(__name__)
        # Pre-formatted message: a caller-supplied ``logger`` may only accept
        # a single positional argument.
        logger.warning('hook {} already exists, overwrite.'.format(name))
        # Overwriting (instead of raising KeyError) is deliberate; unregister
        # the previous hook properly so its teardown callback is not skipped.
        link_hooks[name].deleted(None)
    link_hooks[name] = linkhook
    linkhook.added(None)
    return linkhook
示例#7
0
def add_linkhook(linkhook, prefix='', logger=None):
    """Register a hook in the global link-hook registry.

    The registry key is ``prefix + linkhook.name``. An existing entry under
    the same key is overwritten with a warning rather than raising. The
    displaced hook receives its ``deleted(None)`` callback first, so every
    ``added(None)`` stays paired with a ``deleted(None)`` just as in
    ``delete_linkhook``.

    Args:
        linkhook: Hook to register; its ``added(None)`` callback is invoked.
        prefix (str): String prepended to ``linkhook.name`` to form the key.
        logger: Object with a ``warning`` method, used to report an
            overwrite. Defaults to ``getLogger(__name__)`` when falsy.

    Returns:
        The ``linkhook`` argument.
    """
    link_hooks = chainer._get_link_hooks()
    name = prefix + linkhook.name
    if name in link_hooks:
        logger = logger or getLogger(__name__)
        # Pre-formatted message: a caller-supplied ``logger`` may only accept
        # a single positional argument.
        logger.warning('hook {} already exists, overwrite.'.format(name))
        # Overwriting (instead of raising KeyError) is deliberate; unregister
        # the previous hook properly so its teardown callback is not skipped.
        link_hooks[name].deleted(None)
    link_hooks[name] = linkhook
    linkhook.added(None)
    return linkhook
 def __exit__(self, *_):
     """Unregister this hook from the global link-hook registry.

     The entry stored under ``self.name`` gets its ``deleted(None)``
     callback before being removed. Returns ``None``, so exceptions raised
     inside the ``with`` block propagate.
     """
     registry = chainer._get_link_hooks()
     key = self.name
     registry[key].deleted(None)
     registry.pop(key)
示例#9
0
 def __exit__(self, *_):
     """Unregister this hook from the global link-hook registry.

     The entry stored under ``self.name`` gets its ``deleted(None)``
     callback before being removed. Returns ``None``, so exceptions raised
     inside the ``with`` block propagate.
     """
     registry = chainer._get_link_hooks()
     key = self.name
     registry[key].deleted(None)
     registry.pop(key)