class ResAttnBlock(nn.Module):
    # Pre-LayerNorm residual block: factored attention followed by an MLP,
    # with optional gradient checkpointing and residual scaling.
    def __init__(self, n_in, n_ctx, n_head,
                 attn_dropout=0.0, resid_dropout=0.0,
                 afn='quick_gelu', scale=True, mask=False,
                 zero_out=False, init_scale=1.0, res_scale=1.0,
                 m_attn=0.25, m_mlp=1.,
                 checkpoint_attn=0, checkpoint_mlp=0,
                 attn_func=0, blocks=None, spread=None,
                 encoder_dims=None, prime_len=None):
        super().__init__()
        self.attn = FactoredAttention(n_in=n_in, n_ctx=n_ctx, n_state=int(m_attn * n_in), n_head=n_head,
                                      attn_dropout=attn_dropout, resid_dropout=resid_dropout,
                                      scale=scale, mask=mask,
                                      zero_out=zero_out, init_scale=init_scale,
                                      checkpoint_attn=checkpoint_attn,
                                      attn_func=attn_func, blocks=blocks, spread=spread,
                                      encoder_dims=encoder_dims, prime_len=prime_len)
        self.ln_0 = LayerNorm(n_in)
        self.mlp = MLP(n_in=n_in, n_state=int(m_mlp * n_in),
                       resid_dropout=resid_dropout,
                       afn=afn, zero_out=zero_out, init_scale=init_scale)
        self.ln_1 = LayerNorm(n_in)
        self.res_scale = res_scale

        self.checkpoint_attn = checkpoint_attn
        self.checkpoint_mlp = checkpoint_mlp
        self.n_in = n_in
        self.attn_func = attn_func

    def forward(self, x, encoder_kv, sample=False):
        if sample:
            # Sampling: no checkpointing; the attention layer may use its cache.
            a = self.attn(self.ln_0(x), encoder_kv, sample)
            m = self.mlp(self.ln_1(x + a))
        else:
            if self.attn_func == 6:
                # Encoder-decoder attention requires the encoder's keys/values.
                assert encoder_kv is not None
                a = checkpoint(lambda _x, _enc_kv, _s=sample: self.attn(self.ln_0(_x), _enc_kv, _s),
                               (x, encoder_kv),
                               (*self.attn.parameters(), *self.ln_0.parameters()),
                               self.checkpoint_attn == 3)  # 2 recomputes after the projections, and 1 recomputes after head splitting.
            else:
                assert encoder_kv is None
                a = checkpoint(lambda _x, _enc_kv=None, _s=sample: self.attn(self.ln_0(_x), _enc_kv, _s),
                               (x,),
                               (*self.attn.parameters(), *self.ln_0.parameters()),
                               self.checkpoint_attn == 3)  # 2 recomputes after the projections, and 1 recomputes after head splitting.
            m = checkpoint(lambda _x: self.mlp(self.ln_1(_x)), (x + a,),
                           (*self.mlp.parameters(), *self.ln_1.parameters()),
                           self.checkpoint_mlp == 1)
        if self.res_scale == 1.0:
            h = x + a + m
        else:
            h = x + self.res_scale * (a + m)
        return h
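
# A minimal smoke-test sketch, not part of the original file. It assumes the
# definitions used above (ResAttnBlock, FactoredAttention, MLP, LayerNorm,
# checkpoint) are available in this module, that inputs are shaped
# (batch, length, n_in), and that attn_func=0 selects plain dense
# self-attention (so encoder_kv must be None). Hyperparameters are illustrative.
if __name__ == '__main__':
    import torch

    n_batch, n_ctx, n_in = 2, 128, 64
    block = ResAttnBlock(n_in=n_in, n_ctx=n_ctx, n_head=4, mask=True, attn_func=0)
    x = torch.randn(n_batch, n_ctx, n_in)

    # Training-style pass: self-attention only, so no encoder keys/values.
    h = block(x, encoder_kv=None, sample=False)
    assert h.shape == x.shape  # the residual block preserves the input shape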