# Example no. 1 (scraped listing marker)
    def __call__(self, module: nn.Module):
        """Wrap ``module`` in place for Mixout (https://arxiv.org/abs/1909.11299).

        Each direct parameter of ``module`` is replaced by a frozen copy of
        its current value, while the learnable original is kept in a
        ``ParameterDict``; a forward pre-hook then supplies the mixed values
        before every forward pass.

        Use with:
        >>> mixout_model = model.apply(MixoutWrapper).
        """
        module._names = []
        module._params_orig = {}
        learned = nn.ParameterDict()
        for name, param in list(module.named_parameters(recurse=False)):
            frozen = param.clone().detach()
            frozen.requires_grad = False  # pretrained copy stays frozen
            module._params_orig[name] = frozen
            learned[name] = param
            module._names.append(name)
            # Swap the Parameter attribute for the frozen plain tensor so the
            # forward pass reads whatever the hook assigns on each call.
            delattr(module, name)
            setattr(module, name, frozen)
        if module._names:
            module._params_learned = learned

        # NOTE(review): Hook is defined elsewhere in the project; presumably
        # it performs the per-forward mixing using probability self.p.
        self.hook = Hook(self.p)

        module.register_forward_pre_hook(self.hook)
        return module
# Example no. 2 (scraped listing marker)
def MixoutWrapper(module: nn.Module,
                  p: float = 0.7,
                  exclude: str = 'layer_norm'):
    """
    Implementation of Mixout (https://arxiv.org/abs/1909.11299).

    Replaces each direct (non-recursive) parameter of ``module`` with a
    frozen copy of its current ("pretrained") value and registers a forward
    pre-hook that, in training mode, mixes pretrained and learned values
    element-wise, keeping the pretrained value with probability ``p``.

    Use with:
    >>> mixout_model = model.apply(MixoutWrapper).

    Args:
        module: module whose direct parameters are wrapped in place.
        p: per-element probability of using the pretrained value; must
           satisfy 0 <= p < 1 (the training-mode mix divides by 1 - p).
        exclude: comma-separated substrings; parameters whose name contains
           any of them are left untouched (e.g. layer norms).

    Returns:
        The same ``module``, modified in place.
    """
    # Duplicate all the parameters, making frozen copies of the pretrained
    # values and collecting the learnable originals in a ParameterDict.
    module._names = []
    module._params_orig = dict()
    _params_learned = nn.ParameterDict()
    # Strip whitespace and drop empty tokens. Previously 'a, b' failed to
    # exclude 'b' (the token was ' b'), and exclude='' produced [''],
    # which — since '' is a substring of every name — excluded everything.
    exclude_keys = [k.strip() for k in exclude.split(',') if k.strip()]
    for n, q in list(module.named_parameters(recurse=False)):
        if any(k in n for k in exclude_keys):
            continue
        c = q.clone().detach()
        c.requires_grad = False  # pretrained copy stays frozen
        module._params_orig[n] = c
        _params_learned[n] = q
        module._names.append(n)
        # Replace the Parameter attribute with the frozen plain tensor so
        # the forward pass reads whatever the hook assigns each call.
        delattr(module, n)
        setattr(module, n, c)
    if module._names:
        module._params_learned = _params_learned

    def mixout(module, n):
        """Return the (possibly mixed) tensor for parameter ``n``."""
        if module.training:
            o = module._params_orig[n]
            # mask == 1 selects the pretrained value (probability p).
            mask = (torch.rand_like(o) < p).type_as(o)
            # Unbiased mix: E[result] equals the learned parameter.
            return (mask * o +
                    (1 - mask) * module._params_learned[n] -
                    p * o) / (1 - p)
        else:
            # Evaluation uses the learned values directly.
            return module._params_learned[n].data

    def hook(module, input):
        # Refresh the mixed parameter tensors before every forward pass.
        for n in module._names:
            v = mixout(module, n)
            setattr(module, n, v)

    # A hook over an empty name list would be a pointless no-op, so only
    # attach it when at least one parameter was actually wrapped.
    if module._names:
        module.register_forward_pre_hook(hook)
    return module