Example #1
import torch.nn as nn

# FactoredAttention, LayerNorm, MLP, and checkpoint are helpers defined
# alongside this class in the Jukebox transformer code; they are assumed to be
# in scope for this excerpt.
class ResAttnBlock(nn.Module):
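    """Pre-LayerNorm residual attention block.

    Computes a = attn(ln_0(x)) and m = mlp(ln_1(x + a)), then returns
    x + res_scale * (a + m); res_scale == 1.0 takes the unscaled fast path.
    """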
    def __init__(self, n_in, n_ctx, n_head,
                 attn_dropout=0.0, resid_dropout=0.0,
                 afn='quick_gelu', scale=True, mask=False,
                 zero_out=False, init_scale=1.0, res_scale=1.0,
                 m_attn = 0.25, m_mlp = 1.,
                 checkpoint_attn = 0, checkpoint_mlp = 0,
                 attn_func=0, blocks=None, spread=None,
                 encoder_dims=None, prime_len=None):
        super().__init__()
        self.attn = FactoredAttention(n_in=n_in, n_ctx=n_ctx, n_state=int(m_attn * n_in), n_head=n_head,
                                      attn_dropout=attn_dropout, resid_dropout=resid_dropout,
                                      scale=scale, mask=mask,
                                      zero_out=zero_out, init_scale=init_scale,
                                      checkpoint_attn=checkpoint_attn,
                                      attn_func=attn_func, blocks=blocks, spread=spread,
                                      encoder_dims=encoder_dims, prime_len=prime_len)
        self.ln_0 = LayerNorm(n_in)
        self.mlp = MLP(n_in=n_in, n_state=int(m_mlp * n_in),
                       resid_dropout=resid_dropout,
                       afn=afn,
                       zero_out=zero_out, init_scale=init_scale)
        self.ln_1 = LayerNorm(n_in)
        self.res_scale = res_scale

        self.checkpoint_attn = checkpoint_attn
        self.checkpoint_mlp = checkpoint_mlp
        self.n_in = n_in
        self.attn_func = attn_func

    def forward(self, x, encoder_kv, sample=False):
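        # encoder_kv is consumed only when attn_func == 6; for every other
        # attn_func it must be None (see the asserts below). When sample=True
        # the gradient-checkpointed paths are bypassed entirely.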
        if sample:
            a = self.attn(self.ln_0(x), encoder_kv, sample)
            m = self.mlp(self.ln_1(x + a))
        else:
            if self.attn_func == 6:
                assert encoder_kv is not None
                a = checkpoint(lambda _x,_enc_kv,_s=sample: self.attn(self.ln_0(_x),_enc_kv,_s),
                               (x,encoder_kv),
                               (*self.attn.parameters(), *self.ln_0.parameters()),
                               self.checkpoint_attn == 3)  # 2 recomputes after the projections, and 1 recomputes after head splitting.
            else:
                assert encoder_kv is None
                a = checkpoint(lambda _x,_enc_kv=None,_s=sample: self.attn(self.ln_0(_x),_enc_kv,_s),
                               (x,),
                               (*self.attn.parameters(), *self.ln_0.parameters()),
                               self.checkpoint_attn == 3)  # 2 recomputes after the projections, and 1 recomputes after head splitting.
            m = checkpoint(lambda _x: self.mlp(self.ln_1(_x)), (x + a,),
                           (*self.mlp.parameters(), *self.ln_1.parameters()),
                           self.checkpoint_mlp == 1)
        if self.res_scale == 1.0:
            h = x + a + m
        else:
            h = x + self.res_scale * (a + m)
        return h
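A minimal usage sketch, assuming ResAttnBlock and its FactoredAttention/MLP/LayerNorm/checkpoint dependencies are importable from the Jukebox transformer package; the import path and the sizes below are illustrative, not taken from the example above:

import torch
# Hypothetical import path; adjust to wherever ResAttnBlock lives in your checkout.
from jukebox.transformer.transformer import ResAttnBlock

block = ResAttnBlock(n_in=64, n_ctx=128, n_head=4)  # illustrative sizes
x = torch.randn(2, 128, 64)                         # (batch, n_ctx, n_in)
h = block(x, encoder_kv=None)                       # default attn_func=0, so no encoder keys/values
print(h.shape)                                      # residual output keeps the input shape: (2, 128, 64)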