Code example #1
    def forward(self, h, pe, attn_mask, mem=None, hidden=None):
        new_mem = None
        new_hidden = None  # stays None when no RNN is configured
        h = self.lnstart(h)
        if self.rnn:
            x, new_hidden = self.rnn(h, None if hidden is None else hidden)
            # Trim the end off if the size is different
            ninp = h.shape[-1]
            z = torch.narrow(x, -1, 0, x.shape[-1] // ninp * ninp)
            # Divide the hidden size evenly into chunks, then collapse them through summation
            z = z.view(*z.shape[:-1], z.shape[-1] // ninp, ninp)
            x = self.drop(z).sum(dim=-2)
            h = h + x if self.residual else x.float()
        focus, new_mem = None, []
        if self.attn is not None:
            mh = self.lnmem(h)
            h = self.lnmid(h)
            if mem is not None:
                bigh = torch.cat([mem, mh], dim=0)
            else:
                bigh = mh
            new_mem = bigh[-len(pe):]
            q, k = h, bigh
            x, focus = checkpoint(self.attn, q, k, bigh, attn_mask)
            x = self.drop(x)
            h = x + h
        if self.ff:
            h, x = self.lnff(h), self.lnxff(h)
            x = checkpoint(self.ff, x)
            x = self.drop(x)
            h = x + h
        return h, new_mem, new_hidden, focus
Code example #2
    def forward(self, x):
        size = (x.shape[2], x.shape[3])
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        if self.seg:
            # Checkpoint each residual block to trade recompute for activation memory
            for module in self.layer1._modules.values():
                x = checkpoint(module, x)
            for module in self.layer2._modules.values():
                x = checkpoint(module, x)
            for module in self.layer3._modules.values():
                x = checkpoint(module, x)
            for module in self.layer4._modules.values():
                x = checkpoint(module, x)
            x = self.aspp(x)
            x = nn.Upsample(size, mode='bilinear', align_corners=True)(x)
        else:
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
        return x
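
A note on the upsampling call in the segmentation branch above: nn.Upsample(size, mode='bilinear', align_corners=True)(x) constructs a fresh module on every forward pass. The functional form torch.nn.functional.interpolate produces the same result; a minimal sketch with toy shapes that are not from the original project:

import torch
import torch.nn.functional as F

# Toy 4-D feature map [N, C, H, W]; `size` is the target spatial size
x = torch.randn(1, 3, 16, 16)
size = (64, 64)

# Same result as nn.Upsample(size, mode='bilinear', align_corners=True)(x),
# without building a module object inside forward()
y = F.interpolate(x, size=size, mode='bilinear', align_corners=True)
assert y.shape == (1, 3, 64, 64)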
Code example #3
File: model.py  Project: CanItRun/sha-rnn
    def forward(self, h, pe, attn_mask, mem=None, hidden=None):
        new_mem = None
        new_hidden = None  # stays None when no RNN is configured

        h = self.lnstart(h)

        if self.rnn:
            x, new_hidden = self.rnn(h, None if hidden is None else hidden)
            #x = self.rnn_down(self.drop(x))

            # Trim the end off if the size is different
            ninp = h.shape[-1]
            z = torch.narrow(x, -1, 0, x.shape[-1] // ninp * ninp)
            # Divide the hidden size evenly into chunks
            z = z.view(*z.shape[:-1], z.shape[-1] // ninp, ninp)
            # Collapse the chunks through summation
            #h = h + self.drop(x).sum(dim=-2)
            x = self.drop(z).sum(dim=-2)
            #x = x + z.sum(dim=-2)

            h = h + x if self.residual else x.float()

        focus, new_mem = None, []
        if self.attn is not None:
            mh = self.lnmem(h)
            h = self.lnmid(h)

            if mem is not None:
                bigh = torch.cat([mem, mh], dim=0)
            else:
                bigh = mh
            new_mem = bigh[-len(pe):]

            q, k = h, bigh

            x, focus = checkpoint(self.attn, q, k, bigh, attn_mask)
            #x, focus = tcheckpoint(self.attn, q, k, bigh, attn_mask)
            x = self.drop(x)
            h = x + h

        if self.ff:
            h, x = self.lnff(h), self.lnxff(h)
            x = checkpoint(self.ff, x)
            #x = tcheckpoint(self.ff, h)
            x = self.drop(x)
            h = x + h

        return h, new_mem, new_hidden, focus
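
The narrow / view / sum sequence in this and the previous example projects an RNN output whose width is a multiple of ninp back down to ninp by splitting it into chunks and summing them. A minimal standalone sketch of that trick (the shapes here are made up for illustration):

import torch

seq, batch, ninp, k = 5, 2, 8, 3
x = torch.randn(seq, batch, k * ninp)  # RNN output wider than the model dimension

# Trim the end off if the width is not an exact multiple of ninp
x = torch.narrow(x, -1, 0, x.shape[-1] // ninp * ninp)
# Divide the last dimension evenly into chunks of size ninp
x = x.view(*x.shape[:-1], x.shape[-1] // ninp, ninp)
# Collapse the chunks through summation
x = x.sum(dim=-2)

assert x.shape == (seq, batch, ninp)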
Code example #4
File: dla.py  Project: mymuli/elastic
    def forward(self, x):
        y = []
        x = self.base_layer(x)
        for i in range(6):
            if self.seg:
                x = checkpoint(getattr(self, 'level{}'.format(i)), x)
            else:
                x = getattr(self, 'level{}'.format(i))(x)
            y.append(x)
        if self.return_levels:
            return y
        else:
            x = self.avgpool(x)
            x = self.fc(x)
            x = x.view(x.size(0), -1)
            return x
Code example #5
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )

            # ---- <changed lines> ----
            layer_outputs = checkpoint(layer_module, hidden_states,
                                       attention_mask, head_mask[i],
                                       encoder_hidden_states,
                                       encoder_attention_mask)
            # ---- </changed lines> ----

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1], )

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        outputs = (hidden_states, )
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states, )
        if output_attentions:
            outputs = outputs + (all_attentions, )
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
Code example #6
    def forward(self, x, head_m=None, cache=None, **kw):
        cfg = self.cfg
        yo = self.get_y_opts(**kw)
        y = x
        attns = () if yo.attn else None
        caches = () if yo.cache else None
        crosses = () if yo.attn and cfg.add_cross else None
        hiddens = () if yo.hidden else None
        for i, lay in enumerate(self.lays):
            if yo.hidden:
                hiddens += (y, )
            h = head_m[i] if head_m is not None else None
            c = cache[i] if cache is not None else None
            if cfg.grad_checkpoint and self.training:
                if yo.cache:
                    yo.cache = False

                def create_forward(x):
                    def forward(*xs):
                        return x(*xs, cache=c, yo=yo)

                    return forward

                ys = checkpoint(create_forward(lay), y, **kw, mask=h)
            else:
                ys = lay(y, **kw, cache=c, mask=h, yo=yo)
            y = ys[0]
            if yo.attn:
                attns += (ys[1], )
                if cfg.add_cross:
                    crosses += (ys[2], )
            if yo.cache:
                caches += (ys[-1], )
        if yo.hidden:
            hiddens += (y, )
        ys = (y, attns, caches, crosses, hiddens)
        return qo.CachesCrosses(*ys) if yo.kw else ys
Code example #7
File: gpt_neox.py  Project: chimjoHA/gpt-neox
    def forward(self, x, mask=None):
        n, device = x.shape[1], x.device

        x = self.token_emb(x)
        x = self.pos_emb(torch.arange(n, device=device)) + x

        def _layer(attn, ff):
            def fn(x):
                x = attn(x) + x
                return ff(x) + x

            return fn

        if self.gradient_checkpointing:
            for (attn, ff) in self.layers:
                layer_fn = _layer(attn, ff)
                x = checkpoint(layer_fn, x)
        else:
            for (attn, ff) in self.layers:
                layer_fn = _layer(attn, ff)
                x = layer_fn(x)

        x = self.norm(x)
        return self.to_logits(x)
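
All of these examples hand torch.utils.checkpoint.checkpoint a callable plus its tensor inputs; the activations produced inside that callable are not stored during the forward pass and are recomputed during backward. The closure in the gpt-neox example exists because checkpoint expects a single callable over tensors. A minimal self-contained sketch of the same pattern (the toy block and sizes are invented for illustration):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy residual block standing in for one (attn, ff) pair
block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

def layer_fn(x):
    # Residual wrapper, analogous to _layer(attn, ff) above
    return block(x) + x

x = torch.randn(4, 16, requires_grad=True)
# Intermediate activations inside layer_fn are recomputed on backward
y = checkpoint(layer_fn, x)
y.sum().backward()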
Code example #8
File: model.py  Project: xzm2004260/sha-rnn
    def forward(self,
                x,
                hidden=None,
                mems=None,
                padding_mask=None,
                return_h=True):
        """ Input has shape [seq length, batch] """
        e = self.encoder(x)
        e = self.idrop(e)

        if mems is not None:
            maxmem = self.num_max_positions - len(e)
            mems = [m[-maxmem:] for m in mems]
            #print('maxmem: {}, mem[0] length: {}'.format(maxmem, len(mems[0])))

        total_length = len(x) + (len(mems[0]) if mems else 0)
        #positions = torch.arange(total_length, device=x.device).unsqueeze(-1)
        #pe = self.position_embeddings(positions).expand(total_length, *e.shape[1:])
        pe = self.position_embeddings[-total_length:].expand(
            total_length, *e.shape[1:])
        pe = self.idrop(pe)
        pe = self.pe_norm(pe)

        h = e

        if False and mems is not None:
            q = pe[-len(h):] + h
            k = pe[-total_length:-len(h)] + mems[-1]
            x, _ = checkpoint(self.start_attn, q, k, mems[-1])
            x = self.drop(x)
            f = torch.sigmoid(self.start_mix)
            h = h + f * x

        new_hidden = []

        if self.use_qrnn:
            if hidden is None:
                hidden = None
            x, new_h = self.qrnn_out(h, None if hidden is None else hidden[0])
            new_hidden.append(new_h)
            # Trim the end off if the size is different
            x = torch.narrow(x, -1, 0, x.shape[-1] // self.ninp * self.ninp)
            # Divide the hidden size evenly into chunks
            x = x.view(*x.shape[:-1], x.shape[-1] // self.ninp, self.ninp)
            # Collapse the chunks through summation
            h = h + self.idrop(x).sum(dim=-2)

        new_mems = []

        attn_mask = None
        if self.causal:
            attn_mask = torch.full((len(x), len(x)),
                                   -float('Inf'),
                                   device=h.device,
                                   dtype=h.dtype)
            attn_mask = torch.triu(attn_mask, diagonal=1)
            if mems:
                happy = torch.zeros((len(x), len(mems[0])),
                                    device=h.device,
                                    dtype=h.dtype)
                attn_mask = torch.cat([happy, attn_mask], dim=-1)

        for idx, block in enumerate(self.blocks):
            #h, mem = checkpoint(block, h, pe, attn_mask, mems[idx]) if mems else checkpoint(block, h, pe, attn_mask)
            p = torch.sigmoid(self.position_gates[idx]) * pe
            h, mem, hid = block(
                h, p, attn_mask, mems[idx] if mems else None,
                None if hidden is None or len(hidden) <= 2 else
                hidden[(1 if self.use_qrnn else 0) + idx])
            if hid is not None: new_hidden.append(hid)
            # Cast to half to save a tiny bit of space =]
            new_mems.append(mem.half())

        if self.use_qrnn and self.qrnn_outer:
            if hidden is None:
                hidden = None
            x, new_h = self.qrnn_outer(h,
                                       None if hidden is None else hidden[-1])
            new_hidden.append(new_h)
            # Trim the end off if the size is different
            x = torch.narrow(x, -1, 0, x.shape[-1] // self.ninp * self.ninp)
            # Divide the hidden size evenly into chunks
            x = x.view(*x.shape[:-1], x.shape[-1] // self.ninp, self.ninp)
            # Collapse the chunks through summation
            h = h + self.idrop(x).sum(dim=-2)
            #h = x.sum(dim=-2)

        #q = pe[-len(h):] + h
        #k = pe[-len(h):] + h
        #m, _ = checkpoint(self.mem_attn, q, k, h)
        #m = checkpoint(self.mem_boom, h + m)
        #m = self.mem_ln(m)
        new_mems.append(h)

        #h = self.final_act(h)
        #h = self.final_ln(h)
        #x = self.final_boom(h)
        #x = self.drop(x)
        #h = x + h

        h = self.drop(h)

        if return_h:
            return h, new_hidden, new_mems, None, None
        return h, new_hidden, new_mems
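
The causal mask built in this example is an additive attention mask: -inf above the diagonal blocks attention to future positions, and a block of zero columns is prepended so every query can still attend to the memory tokens that precede the current sequence. A minimal sketch with invented sizes:

import torch

seq_len, mem_len = 4, 3
attn_mask = torch.full((seq_len, seq_len), -float('Inf'))
attn_mask = torch.triu(attn_mask, diagonal=1)      # -inf strictly above the diagonal
mem_cols = torch.zeros((seq_len, mem_len))         # memory positions stay visible
attn_mask = torch.cat([mem_cols, attn_mask], dim=-1)
assert attn_mask.shape == (seq_len, mem_len + seq_len)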
Code example #9
File: model.py  Project: xzm2004260/sha-rnn
    def forward(self, h, pe, attn_mask, mem=None, hidden=None):
        new_hiddens = []

        h = self.lnq(h)

        if self.rnn:
            x, nh = self.rnn(h, None if hidden is None else hidden[0])
            new_hiddens.append(nh)
            # Trim the end off if the size is different
            ninp = h.shape[-1]
            x = torch.narrow(x, -1, 0, x.shape[-1] // ninp * ninp)
            # Divide the hidden size evenly into chunks
            x = x.view(*x.shape[:-1], x.shape[-1] // ninp, ninp)
            # Collapse the chunks through summation
            #h = h + self.drop(x).sum(dim=-2)
            h = self.drop(x).sum(dim=-2)
            #h = torch.sigmoid(self.residual_gate) * h + self.drop(x).sum(dim=-2)

        if mem is not None and self.memrnn:
            #mhid = [mh.expand_as(mem) for mh in self.mem_hidden]
            x, _ = self.memrnn(mem)
            # Trim the end off if the size is different
            ninp = h.shape[-1]
            x = torch.narrow(x, -1, 0, x.shape[-1] // ninp * ninp)
            # Divide the hidden size evenly into chunks
            x = x.view(*x.shape[:-1], x.shape[-1] // ninp, ninp)
            # Collapse the chunks through summation
            mem = mem + self.drop(x).sum(dim=-2)

        if self.memmix is not None and mem is not None:
            # Add a zeroed out end element as the last element can't see the next step
            shifted_m = torch.cat([mem[1:], mem[:1] * 0], dim=0)
            m = torch.cat([mem, shifted_m], dim=-1)
            mem = mem + self.gelu(self.memmix(m))

        #if mem is not None: mem = 0 * mem

        if mem is not None:
            bigh = torch.cat([mem, h], dim=0)
        else:
            bigh = h

        # Store memory at the point the model would have seen it
        #new_mem = h
        new_mem = bigh[-len(pe):]
        #
        #q = pe[-len(m):] + m
        #k = pe[-len(m):] + m
        #m, _ = checkpoint(self.mem_attn, q, k, m)
        #m = checkpoint(self.mem_boom, h + m)

        h = self.ln1(h)
        bigh = self.ln1(bigh)

        q = pe[-len(h):] + h
        k = pe[-len(bigh):] + bigh
        #print(q.shape, k.shape, bigh.shape, attn_mask.shape)
        x, _ = checkpoint(self.attn, q, k, bigh, attn_mask)
        x = self.drop(x)
        h = x + h

        h = self.ln2(h)

        if self.ff:
            x = checkpoint(self.ff, h)
            x = self.drop(x)
            h = x + h

        if self.qrnn:
            h = self.lnq2(h)
            x, nh = self.rnn(h, None if hidden is None else hidden[1])
            new_hiddens.append(nh)
            # Trim the end off if the size is different
            ninp = h.shape[-1]
            x = torch.narrow(x, -1, 0, x.shape[-1] // ninp * ninp)
            # Divide the hidden size evenly into chunks
            x = x.view(*x.shape[:-1], x.shape[-1] // ninp, ninp)
            # Collapse the chunks through summation
            h = h + self.drop(x).sum(dim=-2)
            #h = self.drop(x).sum(dim=-2)

        return h, new_mem, new_hiddens