Example #1
    def reset_parameters_lecun(self, param_init=0.1):
        """Initialize parameters with lecun style."""
        logger.info('===== Initialize %s with lecun style =====' %
                    self.__class__.__name__)
        for conv_layer in [
                self.pointwise_conv1, self.pointwise_conv2, self.depthwise_conv
        ]:
            for n, p in conv_layer.named_parameters():
                init_with_lecun_normal(n, p, param_init)
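
The helper `init_with_lecun_normal` is called in all three examples but not defined in any of them. For reference only, a minimal sketch of such a LeCun-normal initializer is shown below; the signature `(name, parameter, param_init)` is taken from the call sites above, while the body (and whether `param_init` is used at all) is an assumption, not the actual implementation:

import math

import torch.nn as nn


def init_with_lecun_normal(n, p, param_init):
    """Hypothetical LeCun-style initializer: weights ~ N(0, 1/fan_in), biases zeroed.

    The real helper is not shown in these snippets; `param_init` is accepted only
    to match the call sites above, and its exact role here is an assumption.
    """
    if p.dim() == 1:
        nn.init.constant_(p, 0.0)  # bias vectors are set to zero
    elif p.dim() >= 2:
        # fan_in = input features times any kernel dimensions (for conv weights)
        fan_in = p.size(1)
        for k in p.size()[2:]:
            fan_in *= k
        nn.init.normal_(p, mean=0.0, std=1.0 / math.sqrt(fan_in))
    else:
        raise ValueError('unexpected shape for parameter %s: %s' % (n, tuple(p.size())))
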
Example #2
    def __init__(self,
                 kdim,
                 qdim,
                 adim,
                 atype,
                 n_heads,
                 init_r,
                 bias=True,
                 param_init='',
                 conv1d=False,
                 conv_kernel_size=5):

        super().__init__()

        assert conv_kernel_size % 2 == 1, "Kernel size should be odd for 'same' conv."
        self.key = None
        self.mask = None

        self.atype = atype
        assert adim % n_heads == 0
        self.d_k = adim // n_heads
        self.n_heads = n_heads
        self.scale = math.sqrt(adim)

        if atype == 'add':
            self.w_key = nn.Linear(kdim, adim)
            self.v = nn.Linear(adim, n_heads, bias=False)
            self.w_query = nn.Linear(qdim, adim, bias=False)
        elif atype == 'scaled_dot':
            self.w_key = nn.Linear(kdim, adim, bias=bias)
            self.w_query = nn.Linear(qdim, adim, bias=bias)
        else:
            raise NotImplementedError(atype)

        self.r = nn.Parameter(torch.Tensor([init_r]))
        logger.info('init_r is initialized with %d' % init_r)

        self.conv1d = None
        if conv1d:
            self.conv1d = nn.Conv1d(kdim,
                                    kdim,
                                    conv_kernel_size,
                                    padding=(conv_kernel_size - 1) // 2)
            # NOTE: 'same' padding is symmetric over time, so the convolution
            # also sees future frames, i.e. it introduces lookahead
            for n, p in self.conv1d.named_parameters():
                init_with_lecun_normal(n, p, 0.1)

        if atype == 'add':
            self.v = nn.utils.weight_norm(self.v, name='weight', dim=0)
            # initialization
            self.v.weight_g.data = torch.Tensor([1 / adim]).sqrt()
        elif atype == 'scaled_dot':
            if param_init == 'xavier_uniform':
                self.reset_parameters_xavier_uniform(bias)
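
In the 'add' branch near the end, `nn.utils.weight_norm` reparameterizes `self.v` into a direction (`weight_v`) and a per-row magnitude (`weight_g`), so overwriting `weight_g` pins the norm of each energy-projection row to sqrt(1/adim). A standalone sketch of that mechanism, with made-up sizes, purely for illustration:

import torch
import torch.nn as nn

adim, n_heads = 512, 4  # illustrative sizes, not taken from the example
v = nn.utils.weight_norm(nn.Linear(adim, n_heads, bias=False), name='weight', dim=0)

# The effective weight is recomputed as weight_g * weight_v / ||weight_v|| per row,
# so replacing weight_g fixes every row norm at sqrt(1 / adim).
v.weight_g.data = torch.Tensor([1 / adim]).sqrt()

x = torch.randn(2, adim)
print(v(x).shape)            # torch.Size([2, 4])
print(v.weight.norm(dim=1))  # each row norm is approximately sqrt(1 / adim)
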
Example #3
    def reset_parameters(self, param_init):
        """Initialize parameters with lecun style."""
        logger.info('===== Initialize %s with lecun style =====' %
                    self.__class__.__name__)
        for n, p in self.named_parameters():
            init_with_lecun_normal(n, p, param_init)