Beispiel #1
0
    def forward(self,
                x,
                encoder_kv=None,
                sample=False,
                fp16=False,
                fp16_out=False):
        """Feed ``x`` through every module in ``self._attn_mods``, in order.

        Args:
            x: Input to the first module. Converted with ``x.half()`` before
                the loop when ``fp16`` is true.
            encoder_kv: Forwarded only to modules whose ``attn_func == 6``;
                every other module is called with ``encoder_kv=None``.
                (Presumably 6 selects encoder-decoder attention -- confirm
                against the module definition.)
            sample: Passed through to every module. When true, the
                checkpointed path is never taken.
            fp16: Convert the input with ``.half()`` before running.
            fp16_out: When false, the result is converted with ``.float()``
                before being returned.

        Returns:
            The output of the last module (or the possibly-cast input when
            there are no modules).

        Side effects:
            For each module whose ``attn.record_attn`` is truthy, appends
            that module's ``attn.w`` to ``self.ws``.
        """
        if fp16:
            x = x.half()

        # Checkpointing is only used when enabled and we are not sampling.
        use_checkpoint = self.checkpoint_res == 1 and not sample

        for block in self._attn_mods:
            wants_encoder_kv = block.attn_func == 6

            if not use_checkpoint:
                kv = encoder_kv if wants_encoder_kv else None
                x = block(x, encoder_kv=kv, sample=sample)
            elif wants_encoder_kv:
                assert encoder_kv is not None
                # encoder_kv travels as a checkpoint input, so only
                # `sample` is bound ahead of time.
                run = functools.partial(block, sample=sample)
                x = checkpoint(run, (x, encoder_kv), block.parameters(), True)
            else:
                run = functools.partial(block, encoder_kv=None, sample=sample)
                x = checkpoint(run, (x, ), block.parameters(), True)

            if block.attn.record_attn:
                self.ws.append(block.attn.w)

        return x if fp16_out else x.float()
Beispiel #2
0
 def forward(self, x, encoder_kv, sample=False):
     """Apply the attention and MLP sub-layers and add them back onto ``x``.

     Computes ``a = attn(ln_0(x), encoder_kv, sample)`` and
     ``m = mlp(ln_1(x + a))``. Returns ``x + a + m`` when
     ``self.res_scale == 1.0``, otherwise ``x + self.res_scale * (a + m)``.

     When ``sample`` is true both sub-layers are called directly. Otherwise
     each one goes through ``checkpoint``: the attention sub-layer with the
     flag ``self.checkpoint_attn == 3`` and the MLP sub-layer with the flag
     ``self.checkpoint_mlp == 1``.

     Args:
         x: Block input.
         encoder_kv: On the non-sampling path this must be non-None when
             ``self.attn_func == 6`` and must be None otherwise (both are
             enforced with ``assert``).
         sample: Selects the direct, non-checkpointed path when true.
     """
     if sample:
         a = self.attn(self.ln_0(x), encoder_kv, sample)
         m = self.mlp(self.ln_1(x + a))
     else:

         def attn_branch(inp, enc_kv=None, is_sample=sample):
             # LayerNorm followed by attention; re-run by `checkpoint`.
             return self.attn(self.ln_0(inp), enc_kv, is_sample)

         def mlp_branch(inp):
             # LayerNorm followed by the MLP.
             return self.mlp(self.ln_1(inp))

         if self.attn_func == 6:
             assert encoder_kv is not None
             attn_inputs = (x, encoder_kv)
         else:
             assert encoder_kv is None
             attn_inputs = (x, )

         attn_params = (*self.attn.parameters(), *self.ln_0.parameters())
         # Only level 3 is handled here; per the original note, level 2
         # recomputes after the projections and level 1 recomputes after
         # head splitting.
         a = checkpoint(attn_branch, attn_inputs, attn_params,
                        self.checkpoint_attn == 3)

         mlp_params = (*self.mlp.parameters(), *self.ln_1.parameters())
         m = checkpoint(mlp_branch, (x + a, ), mlp_params,
                        self.checkpoint_mlp == 1)

     if self.res_scale == 1.0:
         return x + a + m
     return x + self.res_scale * (a + m)
Beispiel #3
0
 def forward(self, x):
     """Run ``x`` through the network, with or without checkpointing.

     When ``self.checkpoint_res == 1`` each entry of ``self.blocks`` is
     applied in turn through ``checkpoint`` (flag always ``True``), passing
     that block's own parameters. Otherwise the input is handed to
     ``self.model`` in a single call.
     """
     if self.checkpoint_res != 1:
         return self.model(x)

     out = x
     for layer in self.blocks:
         out = checkpoint(layer, (out, ), layer.parameters(), True)
     return out
 def dense_attn(self, query, key, value, sample):
     """Split heads, run ``self._attn``, and merge the heads back.

     ``key`` is split with ``k=True``; ``query`` and ``value`` use the
     default split. The ``self._attn`` call is routed through ``checkpoint``
     (flag ``True``, no extra parameters) only when
     ``self.checkpoint_attn == 1`` and ``sample`` is falsy.

     Returns:
         The result of ``self.merge_heads`` applied to the attention output.
     """
     q = self.split_heads(query)
     k = self.split_heads(key, k=True)
     v = self.split_heads(value)

     if self.checkpoint_attn != 1 or sample:
         # Direct call: checkpointing is off at this level, or sampling.
         heads = self._attn(q, k, v, sample)
     else:

         def run_attn(q_, k_, v_, s_=sample):
             return self._attn(q_, k_, v_, s_)

         heads = checkpoint(run_attn, (q, k, v), (), True)

     return self.merge_heads(heads)
 def forward(self, x, encoder_kv=None, sample=False):
     """Project ``x``, attend, and project the result back out.

     Steps, in order:
       1. ``self.c_attn(x)`` followed by ``self.qkv(...)``, which yields
          query, key, value and a (possibly updated) ``sample`` value.
       2. ``self.attn(query, key, value, sample)``, routed through
          ``checkpoint`` only when ``self.checkpoint_attn == 2`` and
          ``sample`` is falsy.
       3. If the attention output's size along dim 1 differs from the
          input's, a window of the input's length starting at
          ``self._offset(...)`` is sliced out along dim 1.
       4. ``self.c_proj`` and then ``self.resid_dropout``.

     Args:
         x: Input; ``x.shape[1]`` is taken as the current context length.
         encoder_kv: Passed through to ``self.qkv``.
         sample: Passed to ``self.qkv``; the value it returns is the one
             used for the rest of the method.
     """
     ctx_len = x.shape[1]
     projected = self.c_attn(x)
     query, key, value, sample = self.qkv(projected,
                                          encoder_kv=encoder_kv,
                                          sample=sample)

     if self.checkpoint_attn != 2 or sample:
         a = self.attn(query, key, value, sample)
     else:

         def attend(q, k, v, s=sample):
             return self.attn(q, k, v, s)

         a = checkpoint(attend, (query, key, value), (), True)

     if a.shape[1] != ctx_len:
         # Output is longer/shorter than the input: keep only the window
         # that corresponds to the current context.
         start = self._offset(ctx_len)
         a = a[:, start:start + ctx_len, :].contiguous()

     return self.resid_dropout(self.c_proj(a))