コード例 #1
0
ファイル: SHierEncoder.py プロジェクト: hundred06/transformer
    def __init__(self,
                 isize,
                 fhsize=None,
                 dropout=0.0,
                 attn_drop=0.0,
                 num_head=8,
                 ahsize=None):
        """Build three encoder sub-networks and a residue combiner over them.

        The sub-networks, in order, are:
          1. an ``EncoderLayer`` wrapping two ``EncoderLayerBase`` sub-layers
             (``num_sub=2``) whose combiner excludes the raw input
             (``comb_input=False``);
          2. an ``EncoderLayerBase``;
          3. another ``EncoderLayerBase``.
        All of them receive the same sizes, dropout rates and head count.

        Args:
            isize: size passed as the first argument to every sub-network and
                to the combiner (presumably the model dimension — confirm).
            fhsize: hidden size for the sub-networks and the combiner;
                defaults to ``4 * ahsize`` (after ``ahsize`` is resolved).
            dropout: dropout rate forwarded to every sub-network.
            attn_drop: attention dropout rate forwarded to every sub-network.
            num_head: number of attention heads forwarded to every sub-network.
            ahsize: attention hidden size; defaults to ``isize``.
        """

        super(SEncoderLayer, self).__init__()

        _ahsize = isize if ahsize is None else ahsize
        _fhsize = _ahsize * 4 if fhsize is None else fhsize

        self.nets = nn.ModuleList([
            EncoderLayer(isize,
                         _fhsize,
                         dropout,
                         attn_drop,
                         num_head,
                         _ahsize,
                         num_sub=2,
                         comb_input=False),
            EncoderLayerBase(isize, _fhsize, dropout, attn_drop, num_head,
                             _ahsize),
            EncoderLayerBase(isize, _fhsize, dropout, attn_drop, num_head,
                             _ahsize),
        ])

        # Derive the combiner's input count from the sub-network list rather
        # than hard-coding 4, so adding or removing a net cannot leave the two
        # out of sync. This yields 3 nets + 1 = 4, as before. The extra slot is
        # presumably the layer's own input, mirroring EncoderLayer's
        # ``comb_input=True`` sizing (num_sub + 1). TODO confirm against forward().
        self.combiner = ResidueCombiner(isize, len(self.nets) + 1, _fhsize)
コード例 #2
0
    def __init__(self,
                 isize,
                 fhsize=None,
                 dropout=0.0,
                 attn_drop=0.0,
                 num_head=8,
                 ahsize=None,
                 num_sub=1,
                 comb_input=True):
        """Create ``num_sub`` ``EncoderLayerBase`` sub-layers and a combiner.

        Args:
            isize: size passed as the first argument to every sub-layer and to
                the combiner (presumably the model dimension — confirm).
            fhsize: hidden size for the sub-layers and the combiner; defaults
                to four times the resolved attention size.
            dropout: dropout rate forwarded to each sub-layer.
            attn_drop: attention dropout rate forwarded to each sub-layer.
            num_head: number of attention heads forwarded to each sub-layer.
            ahsize: attention hidden size; defaults to ``isize``.
            num_sub: how many ``EncoderLayerBase`` sub-layers to build.
            comb_input: if truthy, the combiner is sized for ``num_sub + 1``
                inputs instead of ``num_sub``. The extra input is presumably the
                layer's own input (TODO confirm in forward()). The flag is also
                stored on ``self.comb_input``.
        """

        super(EncoderLayer, self).__init__()

        attn_size = ahsize if ahsize is not None else isize
        ffn_size = fhsize if fhsize is not None else attn_size * 4

        sub_layers = [
            EncoderLayerBase(isize, ffn_size, dropout, attn_drop, num_head,
                             attn_size)
            for _ in range(num_sub)
        ]
        self.nets = nn.ModuleList(sub_layers)

        # One combiner slot per sub-layer, plus one more when comb_input is set.
        n_inputs = num_sub
        if comb_input:
            n_inputs += 1
        self.combiner = ResidueCombiner(isize, n_inputs, ffn_size)

        self.comb_input = comb_input