Code example #1
0
    def forward(self, inp, gts=None, task=None):
        """Run the encoder/decoder and return the output, or the loss in training.

        The input passes through ``mod1`` .. ``mod7`` (``mod6`` and ``mod7``
        receive ``task``) and then ``aspp2``. The ASPP result is projected by
        ``bot_aspp2`` and resized to the spatial size of the ``mod2`` features.
        It is then concatenated along dim 1 with ``bot_fine2`` applied to those
        same ``mod2`` features. The result goes through ``final2`` and is resized
        to the input's spatial size.

        Args:
            inp: Input tensor. Dims ``[2:]`` are treated as the spatial size
                (presumably ``(N, C, H, W)`` -- TODO confirm).
            gts: Targets handed to ``self.criterion``. Required when
                ``self.training`` is true.
            task: Forwarded unchanged to ``mod6`` and ``mod7``.

        Returns:
            ``self.criterion(out, gts)`` in training mode, otherwise ``out``
            (the upsampled output of ``final2``).

        Raises:
            ValueError: In training mode when ``gts`` is ``None``.
        """
        x_size = inp.size()

        # Encoder. m2 is kept as the higher-resolution skip feature for the decoder.
        x = self.mod1(inp)
        m2 = self.mod2(self.pool2(x))
        x = self.mod3(self.pool3(m2))
        x = self.mod4(x)
        x = self.mod5(x)
        x = self.mod6(x, task=task)
        x = self.mod7(x, task=task)
        x = self.aspp2(x)

        # Decoder: bring projected ASPP features to m2's spatial size and fuse
        # them with projected m2 features by channel concatenation.
        dec0_up = self.bot_aspp2(x)
        dec0_fine = self.bot_fine2(m2)
        dec0_up = Upsample(dec0_up, m2.size()[2:])
        dec0 = torch.cat([dec0_fine, dec0_up], 1)

        dec1 = self.final2(dec0)
        out = Upsample(dec1, x_size[2:])

        if self.training:
            # Fail with a clear message instead of an obscure error inside the loss.
            if gts is None:
                raise ValueError("gts must be provided when the model is in training mode")
            return self.criterion(out, gts)

        return out
Code example #2
0
    def _fwd(self, x, aspp_lo=None, aspp_attn=None, scale_float=None):
        """Run one forward pass at a single input scale.

        The backbone yields ``(s2_features, _, final_features)``. ASPP is applied
        to ``final_features``. Optionally this scale's ASPP output is blended
        with one supplied by another pass. The projected ASPP features are
        refined by ``self.adnb`` and upsampled to the ``s2_features`` size. They
        are then concatenated with the projected ``s2_features`` and fed to both
        the segmentation head (``final``) and the attention head (``scale_attn``).

        Args:
            x: Input tensor. Dims ``[2:]`` are treated as the spatial size.
            aspp_lo: Optional ASPP output to blend in. It is only used when
                ``self.fuse_aspp`` is true and ``aspp_attn`` is also given.
            aspp_attn: Optional blending weights. The result is
                ``aspp_attn * aspp_lo + (1 - aspp_attn) * aspp`` after both are
                resized with ``scale_as`` to match ``aspp``.
            scale_float: Not used by this method.

        Returns:
            Tuple ``(out, logit_attn, aspp_attn, aspp)``:
            ``out`` is the ``final`` output resized to the input's spatial size.
            ``logit_attn`` / ``aspp_attn`` come from the ``scale_attn`` output
            at the same size. They are channel 0 and channels 1: when
            ``self.attn_2b`` is true, otherwise both are the full tensor.
            ``aspp`` is the ASPP output after any blending.
        """
        x_size = x.size()
        s2_features, _, final_features = self.backbone(x)

        aspp = self.aspp(final_features)

        if self.fuse_aspp and aspp_lo is not None and aspp_attn is not None:
            aspp_attn = scale_as(aspp_attn, aspp)
            aspp_lo = scale_as(aspp_lo, aspp)
            aspp = aspp_attn * aspp_lo + (1 - aspp_attn) * aspp

        conv_aspp_ = self.bot_aspp(aspp)
        conv_s2 = self.bot_fine(s2_features)

        # Spatial attention here.
        # conv_aspp_ = self.asnb(conv_s2, conv_aspp_)

        # NOTE(review): this resamples the tensor to its own spatial size and
        # looks like a no-op -- confirm before removing.
        conv_aspp_ = Upsample(conv_aspp_, conv_aspp_.size()[2:])
        conv_aspp_shape = conv_aspp_.shape

        # Single-level call into adnb: an all-False bool mask with shape
        # (dim0, dim2, dim3) of the features, and no positional embedding.
        batch, height, width = conv_aspp_shape[0], conv_aspp_shape[2], conv_aspp_shape[3]
        no_padding_mask = conv_aspp_.new_zeros((batch, height, width), dtype=torch.bool)
        conv_aspp_ = self.adnb([conv_aspp_],
                               masks=[no_padding_mask],
                               pos_embeds=[None])

        # Swap the last two dims and restore the original feature shape.
        # Presumably adnb returns (N, H*W, C) tokens -- TODO confirm.
        # reshape (not view): a transposed tensor is non-contiguous, and view
        # raises when its strides are incompatible. reshape returns the same
        # view whenever view would have succeeded.
        conv_aspp_ = conv_aspp_.transpose(-1, -2).reshape(conv_aspp_shape)

        conv_aspp = Upsample(conv_aspp_, s2_features.size()[2:])

        # Both heads receive the same concatenation. They are built as separate
        # tensors so that neither head can see in-place changes made by the other.
        cat_s4 = torch.cat([conv_s2, conv_aspp], 1)
        cat_s4_attn = torch.cat([conv_s2, conv_aspp], 1)

        final = self.final(cat_s4)
        scale_attn = self.scale_attn(cat_s4_attn)

        out = Upsample(final, x_size[2:])
        scale_attn = Upsample(scale_attn, x_size[2:])

        if self.attn_2b:
            # Channel 0 weights the logits; the remaining channels weight ASPP.
            logit_attn = scale_attn[:, 0:1, :, :]
            aspp_attn = scale_attn[:, 1:, :, :]
        else:
            logit_attn = scale_attn
            aspp_attn = scale_attn

        return out, logit_attn, aspp_attn, aspp