Example #1
    def forward(self, inp, gts=None, task=None):

        x_size = inp.size()
        x = self.mod1(inp)
        m2 = self.mod2(self.pool2(x))
        x = self.mod3(self.pool3(m2))
        x = self.mod4(x)
        x = self.mod5(x)
        x = self.mod6(x, task=task)
        x = self.mod7(x, task=task)
        x = self.aspp2(x)

        dec0_up = self.bot_aspp2(x)
        dec0_fine = self.bot_fine2(m2)
        dec0_up = Upsample(dec0_up, m2.size()[2:])
        dec0 = [dec0_fine, dec0_up]
        dec0 = torch.cat(dec0, 1)

        dec1 = self.final2(dec0)
        out = Upsample(dec1, x_size[2:])

        if self.training:
            print(out.size())
            print(gts.size())
            return self.criterion(out, gts)

        return out  #[:,:19,:,:]
Example #2
    def forward(self, x, gts=None):

        x_size = x.size()  # 800
        x0 = self.layer0(x)  # 400
        x1 = self.layer1(x0)  # 400
        x2 = self.layer2(x1)  # 100
        x3 = self.layer3(x2)  # 100
        x4 = self.layer4(x3)  # 100
        xp = self.aspp(x4)

        dec0_up = self.bot_aspp(xp)
        if self.skip == 'm1':
            dec0_fine = self.bot_fine(x1)
            dec0_up = Upsample(dec0_up, x1.size()[2:])
        else:
            dec0_fine = self.bot_fine(x2)
            dec0_up = Upsample(dec0_up, x2.size()[2:])

        dec0 = [dec0_fine, dec0_up]
        dec0 = torch.cat(dec0, 1)
        dec1 = self.final(dec0)
        main_out = Upsample(dec1, x_size[2:])

        if self.training:
            return self.criterion(main_out, gts)

        return main_out
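
Both forwards above, and every example below, call an Upsample helper that is not part of the listing. A minimal sketch, assuming it simply wraps torch.nn.functional.interpolate with bilinear resampling:

import torch.nn.functional as F

def Upsample(x, size):
    # Minimal sketch (assumption): resize a feature map to the target
    # spatial size (H, W) so decoder features can be fused with skip
    # connections or returned at the input resolution.
    return F.interpolate(x, size=size, mode='bilinear', align_corners=True)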
Example #3
    def _fwd(self, x, aspp_lo=None, aspp_attn=None, scale_float=None):
        x_size = x.size()
        s2_features, _, final_features = self.backbone(x)
        aspp = self.aspp(final_features)

        if self.fuse_aspp and \
           aspp_lo is not None and aspp_attn is not None:
            aspp_attn = scale_as(aspp_attn, aspp)
            aspp_lo = scale_as(aspp_lo, aspp)
            aspp = aspp_attn * aspp_lo + (1 - aspp_attn) * aspp

        conv_aspp = self.bot_aspp(aspp)
        conv_s2 = self.bot_fine(s2_features)
        conv_aspp = Upsample(conv_aspp, s2_features.size()[2:])
        cat_s4 = [conv_s2, conv_aspp]
        cat_s4_attn = [conv_s2, conv_aspp]
        cat_s4 = torch.cat(cat_s4, 1)
        cat_s4_attn = torch.cat(cat_s4_attn, 1)

        final = self.final(cat_s4)
        scale_attn = self.scale_attn(cat_s4_attn)

        out = Upsample(final, x_size[2:])
        scale_attn = Upsample(scale_attn, x_size[2:])

        if self.attn_2b:
            logit_attn = scale_attn[:, 0:1, :, :]
            aspp_attn = scale_attn[:, 1:, :, :]
        else:
            logit_attn = scale_attn
            aspp_attn = scale_attn

        return out, logit_attn, aspp_attn, aspp
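
The fuse_aspp branch above also relies on a scale_as helper. A minimal sketch, assuming it interpolates the first tensor to the spatial size of the second:

import torch.nn.functional as F

def scale_as(x, y):
    # Minimal sketch (assumption): rescale x to match y's spatial size so
    # the lower-scale ASPP features and attention can be fused with the
    # current-scale ASPP output.
    return F.interpolate(x, size=y.size()[2:], mode='bilinear', align_corners=True)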
Example #4
    def forward_with_smear(self, x, smear_layer, smear_mode, init_spIndx,
                           final_spIndx, psp_assoc, spShape):
        _spix_pool_ = lambda xx: spix_pool(xx, init_spIndx, psp_assoc,
                                           final_spIndx, smear_mode, spShape)
        x_size = x.size()
        if smear_layer == 'input': x = _spix_pool_(x)
        x = self.mod1(x)
        if smear_layer == 'mod1': x = _spix_pool_(x)
        m2 = self.mod2(self.pool2(x))
        if smear_layer == 'mod2': m2 = _spix_pool_(m2)
        x = self.mod3(self.pool3(m2))
        if smear_layer == 'mod3': x = _spix_pool_(x)
        x = self.mod4(x)
        if smear_layer == 'mod4': x = _spix_pool_(x)
        x = self.mod5(x)
        if smear_layer == 'mod5': x = _spix_pool_(x)
        x = self.mod6(x)
        if smear_layer == 'mod6': x = _spix_pool_(x)
        x = self.mod7(x)
        if smear_layer == 'mod7': x = _spix_pool_(x)
        x = self.aspp(x)
        if smear_layer == 'aspp': x = _spix_pool_(x)
        dec0_fine = self.bot_fine(m2)
        dec0_up = Upsample(self.bot_aspp(x), m2.size()[2:])
        dec0 = torch.cat([dec0_fine, dec0_up], 1)
        if smear_layer == 'dec0': dec0 = _spix_pool_(dec0)
        dec1 = self.final(dec0)
        if smear_layer == 'dec1': dec1 = _spix_pool_(dec1)
        out = Upsample(dec1, x_size[2:])
        if smear_layer == 'out': out = _spix_pool_(out)
        return out
Example #5
    def forward(self, inp, gts=None):

        x_size = inp.size()
        x = self.mod1(inp)
        m2 = self.mod2(self.pool2(x))
        x = self.mod3(self.pool3(m2))
        x = self.mod4(x)
        x = self.mod5(x)
        x = self.mod6(x)
        x = self.mod7(x)

        x = self.aspp(x)
        dec0_up = self.bot_aspp(x)

        dec0_fine = self.bot_fine(m2)
        dec0_up = Upsample(dec0_up, m2.size()[2:])
        dec0 = [dec0_fine, dec0_up]
        dec0 = torch.cat(dec0, 1)

        dec1 = self.final(dec0)
        out = Upsample(dec1, x_size[2:])

        #         if gts is not None and self.training:
        #             return self.criterion(out, gts)
        return out
Example #6
def none_spix_gather_smear(pFeat, init_spIndx, spShape):
    with torch.no_grad():
        _, _, H1, W1 = init_spIndx.shape
        _, _, H2, W2 = pFeat.shape
        if H1 == H2 and W1 == W2:
            pass
        else:
            pFeat = Upsample(pFeat, size=(H1, W1))  # upsample by interp
            pFeat = Upsample(pFeat, size=(H2, W2))  # downsample by interp
        return pFeat
Example #7
    def _fwd(self, x):
        x_size = x.size()[2:]

        _, _, high_level_features = self.backbone(x)
        cls_out, aux_out, ocr_mid_feats = self.ocr(high_level_features)
        attn = self.scale_attn(ocr_mid_feats)

        aux_out = Upsample(aux_out, x_size)
        cls_out = Upsample(cls_out, x_size)
        attn = Upsample(attn, x_size)

        return {'cls_out': cls_out, 'aux_out': aux_out, 'logit_attn': attn}
Example #8
def hard_spix_gather_smear(pFeat, final_spIndx, spShape):
    with torch.no_grad():
        _, _, H1, W1 = final_spIndx.shape
        _, _, H2, W2 = pFeat.shape
        K = final_spIndx.max().item() + 1
        if H1 == H2 and W1 == W2:
            spFeat, _ = svx.spFeatGather2d(pFeat, final_spIndx, K)
            pFeat = svx.spFeatSmear2d(spFeat, final_spIndx)
        else:
            pFeat = Upsample(pFeat, size=(H1, W1))  # upsample by interp
            spFeat, _ = svx.spFeatGather2d(pFeat, final_spIndx, K)
            pFeat = svx.spFeatSmear2d(spFeat, final_spIndx)
            pFeat = Upsample(pFeat, size=(H2, W2))  # downsample by interp
        return pFeat
Example #9
    def forward(self, x, edge):
        x_size = x.size()

        img_features = self.img_pooling(x)
        img_features = self.img_conv(img_features)
        img_features = Upsample(img_features, x_size[2:])
        out = img_features
        edge_features = Upsample(edge, x_size[2:])
        edge_features = self.edge_conv(edge_features)
        out = torch.cat((out, edge_features), 1)

        for f in self.features:
            y = f(x)
            out = torch.cat((out, y), 1)
        return out
Example #10
    def _fwd_feature(self, x):
        x_size = x.size()
        s2_features, _, final_features = self.backbone(x)
        aspp = self.aspp(final_features)

        conv_aspp = self.bot_aspp(aspp)
        conv_s2 = self.bot_fine(s2_features)
        conv_aspp = Upsample(conv_aspp, s2_features.size()[2:])
        cat_s4 = [conv_s2, conv_aspp]
        cat_s4_attn = [conv_s2, conv_aspp]
        cat_s4 = torch.cat(cat_s4, 1)
        cat_s4_attn = torch.cat(cat_s4_attn, 1)
        final = self.final(cat_s4)
        out = Upsample(final, x_size[2:])
        return out, aspp, cat_s4_attn
Example #11
    def _fwd(self, x, aspp_lo=None, aspp_attn=None, scale_float=None):
        x_size = x.size()
        _, _, final_features = self.backbone(x)
        aspp = self.aspp(final_features)
        aspp = self.bot_aspp(aspp)

        final = self.final(aspp)
        scale_attn = self.scale_attn(aspp)

        out = Upsample(final, x_size[2:])
        scale_attn = Upsample(scale_attn, x_size[2:])

        logit_attn = scale_attn
        aspp_attn = scale_attn

        return out, logit_attn, aspp_attn, aspp
Example #12
def soft_spix_gather_smear(pFeat, init_spIndx, psp_assoc, spShape):
    with torch.no_grad():
        _, _, H1, W1 = init_spIndx.shape
        _, _, H2, W2 = pFeat.shape
        Kh, Kw = spShape
        K = Kh * Kw
        if H1 == H2 and W1 == W2:
            spFeat, _ = svx.spFeatUpdate2d(pFeat, psp_assoc, init_spIndx, Kh,
                                           Kw)
            pFeat = svx.spFeatSoftSmear2d(spFeat, psp_assoc, init_spIndx, Kh,
                                          Kw)
        else:
            pFeat = Upsample(pFeat, size=(H1, W1))  # upsample by interp
            spFeat, _ = svx.spFeatUpdate2d(pFeat, psp_assoc, init_spIndx, Kh,
                                           Kw)
            pFeat = svx.spFeatSoftSmear2d(spFeat, psp_assoc, init_spIndx, Kh,
                                          Kw)
            pFeat = Upsample(pFeat, size=(H2, W2))  # downsample by interp
        return pFeat
Example #13
    def forward(self, inputs):
        assert 'images' in inputs
        x = inputs['images']

        x_size = x.size()
        s2_features, _, final_features = self.backbone(x)
        aspp = self.aspp(final_features)
        conv_aspp = self.bot_aspp(aspp)
        conv_s2 = self.bot_fine(s2_features)
        conv_aspp = Upsample(conv_aspp, s2_features.size()[2:])
        cat_s4 = [conv_s2, conv_aspp]
        cat_s4 = torch.cat(cat_s4, 1)
        final = self.final(cat_s4)
        out = Upsample(final, x_size[2:])

        if self.training:
            assert 'gts' in inputs
            gts = inputs['gts']
            return self.criterion(out, gts)

        return {'pred': out}
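
A short usage sketch for the dictionary-based forward above. The model instance, tensor shapes, and the number of classes (19) are assumptions for illustration:

import torch

# 'model' is assumed to be an instance of the network whose forward is shown above.
inputs = {
    'images': torch.randn(2, 3, 512, 512),        # hypothetical RGB batch
    'gts': torch.randint(0, 19, (2, 512, 512)),   # hypothetical label maps
}

model.train()
loss = model(inputs)                  # training mode returns the criterion value

model.eval()
with torch.no_grad():
    pred = model(inputs)['pred']      # eval mode returns {'pred': upsampled logits}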
Example #14
    def _fwd_attn_rev(self, x, cat_s4_attn):
        x_size = x.size()
        scale_attn_rev = self.scale_attn_rev(cat_s4_attn)
        scale_attn_rev = Upsample(scale_attn_rev, x_size[2:])

        if self.attn_2b:
            logit_attn_rev = scale_attn_rev[:, 0:1, :, :]
            aspp_attn_rev = scale_attn_rev[:, 1:, :, :]
        else:
            logit_attn_rev = scale_attn_rev
            aspp_attn_rev = scale_attn_rev

        return logit_attn_rev, aspp_attn_rev
Example #15
    def _fwd_attn(self, x, cat_s4_attn):
        x_size = x.size()
        scale_attn = self.scale_attn(cat_s4_attn)
        scale_attn = Upsample(scale_attn, x_size[2:])

        if self.attn_2b:
            logit_attn = scale_attn[:, 0:1, :, :]
            aspp_attn = scale_attn[:, 1:, :, :]
        else:
            logit_attn = scale_attn
            aspp_attn = scale_attn

        return logit_attn, aspp_attn
Example #16
    def _fwd(self, x, aspp_lo=None, aspp_attn=None):
        """
        Run the network, and return final feature and logit predictions
        """
        x_size = x.size()
        s2_features, _, final_features = self.backbone(x)
        aspp = self.aspp(final_features)

        if self.fuse_aspp and \
           aspp_lo is not None and aspp_attn is not None:
            aspp_attn = scale_as(aspp_attn, aspp)
            aspp_lo = scale_as(aspp_lo, aspp)
            aspp = aspp_attn * aspp_lo + (1 - aspp_attn) * aspp

        conv_aspp = self.bot_aspp(aspp)
        conv_s2 = self.bot_fine(s2_features)
        conv_aspp = Upsample(conv_aspp, s2_features.size()[2:])
        cat_s4 = [conv_s2, conv_aspp]
        cat_s4 = torch.cat(cat_s4, 1)
        final = self.final(cat_s4)
        out = Upsample(final, x_size[2:])

        return out, cat_s4
Example #17
    def forward(self,
                x,
                gts=None,
                smear_layer='',
                smear_mode='hard',
                init_spIndx=None,
                final_spIndx=None,
                psp_assoc=None,
                spShape=None):
        if smear_layer != '':
            return self.forward_with_smear(x, smear_layer, smear_mode,
                                           init_spIndx, final_spIndx,
                                           psp_assoc, spShape)

        x_size = x.size()
        x = self.mod1(x)
        m2 = self.mod2(self.pool2(x))
        x = self.mod3(self.pool3(m2))
        x = self.mod4(x)
        x = self.mod5(x)
        x = self.mod6(x)
        x = self.mod7(x)
        x = self.aspp(x)
        dec0_up = self.bot_aspp(x)

        dec0_fine = self.bot_fine(m2)
        dec0_up = Upsample(dec0_up, m2.size()[2:])
        dec0 = [dec0_fine, dec0_up]
        dec0 = torch.cat(dec0, 1)

        dec1 = self.final(dec0)
        out = Upsample(dec1, x_size[2:])

        if self.training:
            return self.criterion(out, gts)

        return out
Example #18
    def forward(self, inp, gts=None):

        x_size = inp.size()
        x = self.mod1(inp)
        m2 = self.mod2(self.pool2(x))
        x = self.mod3(self.pool3(m2))
        x = self.mod4(x)
        x = self.mod5(x)
        x1 = self.mod6(x, task='semantic')
        x1 = self.mod7(x1, task='semantic')

        x2 = self.mod6(x, task='traversability')
        x2 = self.mod7(x2, task='traversability')

        xaspp = self.aspp(x1)
        dec0_up = self.bot_aspp(xaspp)
        dec0_fine = self.bot_fine(m2)
        dec0_up = Upsample(dec0_up, m2.size()[2:])
        dec0 = [dec0_fine, dec0_up]
        dec0 = torch.cat(dec0, 1)
        dec1 = self.final(dec0)
        out1 = Upsample(dec1, x_size[2:])

        xaspp = self.aspp2(x2)
        dec0_up = self.bot_aspp2(xaspp)
        dec0_fine = self.bot_fine2(m2)
        dec0_up = Upsample(dec0_up, m2.size()[2:])
        dec0 = [dec0_fine, dec0_up]
        dec0 = torch.cat(dec0, 1)
        dec1 = self.final2(dec0)
        out2 = Upsample(dec1, x_size[2:])

        #        dec1 = self.final2(dec0)
        #        out2 = Upsample(dec1, x_size[2:])

        return out1, out2
Example #19
    def forward(self, inputs):
        assert 'images' in inputs
        x = inputs['images']

        x_size = x.size()
        _, _, final_features = self.backbone(x)
        aspp = self.aspp(final_features)
        final = self.final(aspp)
        out = Upsample(final, x_size[2:])

        if self.training:
            assert 'gts' in inputs
            gts = inputs['gts']
            return self.criterion(out, gts)

        return {'pred': out}
Example #20
    def _fwd(self, x, aspp_lo=None, aspp_attn=None, scale_float=None):
        x_size = x.size()
        s2_features, _, final_features = self.backbone(x)

        aspp = self.aspp(final_features)

        if self.fuse_aspp and \
           aspp_lo is not None and aspp_attn is not None:
            aspp_attn = scale_as(aspp_attn, aspp)
            aspp_lo = scale_as(aspp_lo, aspp)
            aspp = aspp_attn * aspp_lo + (1 - aspp_attn) * aspp

        conv_aspp_ = self.bot_aspp(aspp)
        conv_s2 = self.bot_fine(s2_features)
        # spatial attention here.
        #conv_aspp_ = self.asnb(conv_s2, conv_aspp_)
        conv_aspp_ = Upsample(conv_aspp_, conv_aspp_.size()[2:])
        conv_aspp_shape = conv_aspp_.shape
        conv_aspp_ = self.adnb(
            [conv_aspp_],
            masks=[conv_aspp_.new_zeros((conv_aspp_.shape[0], conv_aspp_.shape[2],
                                         conv_aspp_.shape[3]), dtype=torch.bool)],
            pos_embeds=[None])
        conv_aspp_ = conv_aspp_.transpose(-1, -2).view(conv_aspp_shape)

        conv_aspp = Upsample(conv_aspp_, s2_features.size()[2:])

        cat_s4 = [conv_s2, conv_aspp]
        cat_s4_attn = [conv_s2, conv_aspp]
        cat_s4 = torch.cat(cat_s4, 1)
        cat_s4_attn = torch.cat(cat_s4_attn, 1)

        final = self.final(cat_s4)
        scale_attn = self.scale_attn(cat_s4_attn)

        out = Upsample(final, x_size[2:])
        scale_attn = Upsample(scale_attn, x_size[2:])

        if self.attn_2b:
            logit_attn = scale_attn[:, 0:1, :, :]
            aspp_attn = scale_attn[:, 1:, :, :]
        else:
            logit_attn = scale_attn
            aspp_attn = scale_attn

        return out, logit_attn, aspp_attn, aspp
Example #21
    def forward(self, inputs):
        x = inputs['images']
        x_size = x.size()

        _, _, final_features = self.backbone(x)
        aspp = self.aspp(final_features)
        aspp = self.bot_aspp(aspp)
        pred = self.final(aspp)
        pred = Upsample(pred, x_size[2:])

        if self.training:
            assert 'gts' in inputs
            gts = inputs['gts']
            loss = self.criterion(pred, gts)
            return loss
        else:
            output_dict = {'pred': pred}
            return output_dict
Example #22
    def forward(self, inp_img, audio1, audio6, gts=None, gts_diff_2=None, gts_diff_5=None, gts_depth=None):
        '''batch_size, timesteps, C, H, W = audio1.size()
        c_in1 = audio1.view(batch_size * timesteps, C, H, W);c_in2 = audio6.view(batch_size * timesteps, C, H, W);
        audio_conv1feature = self.audionet_convlayer1(c_in1);audio_conv1feature2 = self.audionet_convlayer1(c_in2)
        audio_conv2feature = self.audionet_convlayer2(audio_conv1feature);audio_conv2feature2 = self.audionet_convlayer2(audio_conv1feature2)
        audio_conv3feature = self.audionet_convlayer3(audio_conv2feature);audio_conv3feature2 = self.audionet_convlayer3(audio_conv2feature2)
        audio_conv4feature = self.audionet_convlayer4(audio_conv3feature);audio_conv4feature2 = self.audionet_convlayer4(audio_conv3feature2)
        audio_conv5feature = self.audionet_convlayer5(audio_conv4feature);audio_conv5feature2 = self.audionet_convlayer5(audio_conv4feature2)
        audio_feat = audio_conv5feature.view(audio_conv5feature.shape[0], -1, 1, 1);audio_feat2 = audio_conv5feature2.view(audio_conv5feature2.shape[0], -1, 1, 1);
        audio_feat = self.conv1x1(audio_feat);audio_feat2 = self.conv1x1(audio_feat2)
        r_in = audio_feat.view(batch_size, timesteps, -1);r_in2 = audio_feat2.view(batch_size, timesteps, -1)
        '''
        out_aud1 = self.forward_Seg(audio1)
        out_aud6 = self.forward_Seg(audio6)

        #print(inp.size())
        #x_size = inp_img.size()
        #out_aud1=self.unet(audio1);out_aud6 = self.unet(audio6);
        #x = self.mod1(inp_img)
        #m2 = self.mod2(self.pool2(x))
        #x = self.mod3(self.pool3(m2))
        #x = self.mod4(x)
        #x = self.mod5(x)
        #x = self.mod6(x)
        #x = self.mod7(x)
        mask_prediction, mask_prediction2 = self.forward_SASR(audio1, audio6)

        #print(mask_prediction2.shape,gts_diff_5.shape)
        loss = self.MSEcriterion(mask_prediction, gts_diff_2) + self.MSEcriterion(mask_prediction2, gts_diff_5)

        #x = self.aspp(x)
        #dec0_up = self.bot_aspp(x);print(dec0_up.shape)
        dec0_aud1 = Upsample(out_aud1, [60, 120])
        dec0_aud1 = self.bot_aud1(dec0_aud1)
        dec0_aud6 = Upsample(out_aud6, [60, 120])
        dec0_aud6 = self.bot_aud1(dec0_aud6)
        dec0_aud = [dec0_aud1, dec0_aud6]
        dec0_aud = torch.cat(dec0_aud, 1)
        dec0_aud = self.bot_multiaud(dec0_aud)
        #dec0_up = [dec0_up,dec0_aud];dec0_up = torch.cat(dec0_up,1);
        dec0_auds = self.aspp(dec0_aud)
        dec0_audd = self.depthaspp(dec0_aud)
        dec0_up = self.bot_aspp(dec0_auds)
        dec0_upd = self.bot_depthaspp(dec0_audd)
        #print(dec0_aud.shape, dec0_up.shape);

        #dec0_fine = self.bot_fine(m2)
        dec0_up = Upsample(dec0_up, [240, 480])
        dec0_upd = Upsample(dec0_upd, [160, 512])
        #dec0 = [dec0_fine, dec0_up]
        #dec0 = torch.cat(dec0, 1)
        #print(dec0.shape, out_aud1.shape, out_aud6.shape)
        dec1 = self.final(dec0_up)
        dec1d = self.final_depth(dec0_upd)
        out = Upsample(dec1, [480, 960])
        outd = Upsample(dec1d, [320, 1024])


        #print(out.size())
        #out=self.final(out)
        #print(out.size(),x_size)
        #out = Upsample(out, x_size[1:])
        #print(out.size(),gts.size())
        #print(out[0,0,0:10,0],gts[0,0:10,0])
        if self.training:
            if loss < 5.0:
                #print(loss,self.criterion(out, gts))
                return 10 * self.criterion(out, gts) + loss + 0.5 * self.MSEcriterion(outd, gts_depth)
            else:
                return 10 * self.criterion(out, gts) + 0.5 * self.MSEcriterion(outd, gts_depth)
        return out, outd
Example #23
    def forward(self, x, gts=None, aux_gts=None, img_gt=None, visualize=False, cal_covstat=False, apply_wtloss=True):
        w_arr = []

        if cal_covstat:
            x = torch.cat(x, dim=0)

        x_size = x.size()  # 800

        if self.trunk == 'mobilenetv2' or self.trunk == 'shufflenetv2':
            x_tuple = self.layer0([x, w_arr])
            x = x_tuple[0]
            w_arr = x_tuple[1]
        else:   # ResNet
            if self.three_input_layer:
                x = self.layer0[0](x)
                if self.args.wt_layer[0] == 1 or self.args.wt_layer[0] == 2:
                    x, w = self.layer0[1](x)
                    w_arr.append(w)
                else:
                    x = self.layer0[1](x)
                x = self.layer0[2](x)
                x = self.layer0[3](x)
                if self.args.wt_layer[1] == 1 or self.args.wt_layer[1] == 2:
                    x, w = self.layer0[4](x)
                    w_arr.append(w)
                else:
                    x = self.layer0[4](x)
                x = self.layer0[5](x)
                x = self.layer0[6](x)
                if self.args.wt_layer[2] == 1 or self.args.wt_layer[2] == 2:
                    x, w = self.layer0[7](x)
                    w_arr.append(w)
                else:
                    x = self.layer0[7](x)
                x = self.layer0[8](x)
                x = self.layer0[9](x)
            else:   # Single Input Layer
                x = self.layer0[0](x)
                if self.args.wt_layer[2] == 1 or self.args.wt_layer[2] == 2:
                    x, w = self.layer0[1](x)
                    w_arr.append(w)
                else:
                    x = self.layer0[1](x)
                x = self.layer0[2](x)
                x = self.layer0[3](x)

        x_tuple = self.layer1([x, w_arr])  # 400
        low_level = x_tuple[0]

        x_tuple = self.layer2(x_tuple)  # 100
        x_tuple = self.layer3(x_tuple)  # 100
        aux_out = x_tuple[0]
        x_tuple = self.layer4(x_tuple)  # 100
        x = x_tuple[0]
        w_arr = x_tuple[1]

        if cal_covstat:
            for index, f_map in enumerate(w_arr):
                # Instance Whitening
                B, C, H, W = f_map.shape  # i-th feature size (B X C X H X W)
                HW = H * W
                f_map = f_map.contiguous().view(B, C, -1)  # B X C X H X W > B X C X (H X W)
                eye, reverse_eye = self.cov_matrix_layer[index].get_eye_matrix()
                f_cor = torch.bmm(f_map, f_map.transpose(1, 2)).div(HW - 1) + (self.eps * eye)  # B X C X C / HW
                off_diag_elements = f_cor * reverse_eye
                #print("here", off_diag_elements.shape)
                self.cov_matrix_layer[index].set_variance_of_covariance(torch.var(off_diag_elements, dim=0))
            return 0

        x = self.aspp(x)
        dec0_up = self.bot_aspp(x)

        dec0_fine = self.bot_fine(low_level)
        dec0_up = Upsample(dec0_up, low_level.size()[2:])
        dec0 = [dec0_fine, dec0_up]
        dec0 = torch.cat(dec0, 1)
        dec1 = self.final1(dec0)
        dec2 = self.final2(dec1)
        main_out = Upsample(dec2, x_size[2:])

        if self.training:
            loss1 = self.criterion(main_out, gts)

            if self.args.use_wtloss:
                wt_loss = torch.FloatTensor([0]).cuda()
                if apply_wtloss:
                    for index, f_map in enumerate(w_arr):
                        eye, mask_matrix, margin, num_remove_cov = self.cov_matrix_layer[index].get_mask_matrix()
                        loss = instance_whitening_loss(f_map, eye, mask_matrix, margin, num_remove_cov)
                        wt_loss = wt_loss + loss
                wt_loss = wt_loss / len(w_arr)

            aux_out = self.dsn(aux_out)
            if aux_gts.dim() == 1:
                aux_gts = gts
            aux_gts = aux_gts.unsqueeze(1).float()
            aux_gts = nn.functional.interpolate(aux_gts, size=aux_out.shape[2:], mode='nearest')
            aux_gts = aux_gts.squeeze(1).long()
            loss2 = self.criterion_aux(aux_out, aux_gts)

            return_loss = [loss1, loss2]
            if self.args.use_wtloss:
                return_loss.append(wt_loss)

            if self.args.use_wtloss and visualize:
                f_cor_arr = []
                for f_map in w_arr:
                    f_cor, _ = get_covariance_matrix(f_map)
                    f_cor_arr.append(f_cor)
                return_loss.append(f_cor_arr)
            return return_loss
        else:
            if visualize:
                f_cor_arr = []
                for f_map in w_arr:
                    f_cor, _ = get_covariance_matrix(f_map)
                    f_cor_arr.append(f_cor)
                return main_out, f_cor_arr
            else:
                return main_out
Example #24
    def forward(self,
                x,
                gts=None,
                aux_gts=None,
                pos=None,
                attention_map=False,
                attention_loss=False):

        x_size = x.size()  # 800

        x = self.layer0(x)  # 400
        x = self.layer1(x)  # 400
        low_level = x
        x = self.layer2(x)  # 100

        x = self.layer3(x)  # 100

        aux_out = x
        x = self.layer4(x)  # 100

        if self.num_attention_layer > 0:
            if attention_map:
                attention_maps = [
                    torch.Tensor() for i in range(self.num_attention_layer)
                ]
                pos_maps = [
                    torch.Tensor() for i in range(self.num_attention_layer)
                ]
                map_index = 0

        if self.args.hanet[0] == 1:
            if attention_map:
                x, attention_maps[map_index], pos_maps[
                    map_index] = self.hanet0(aux_out,
                                             x,
                                             pos,
                                             return_attention=True,
                                             return_posmap=True)
                map_index += 1
            else:
                x = self.hanet0(aux_out, x, pos)

        represent = x

        x = self.aspp(x)

        if self.args.hanet[1] == 1:
            if attention_map:
                x, attention_maps[map_index], pos_maps[
                    map_index] = self.hanet1(represent,
                                             x,
                                             pos,
                                             return_attention=True,
                                             return_posmap=True)
                map_index += 1
            else:
                x = self.hanet1(represent, x, pos)

        dec0_up = self.bot_aspp(x)

        if self.args.hanet[2] == 1:
            if attention_map:
                dec0_up, attention_maps[map_index], pos_maps[
                    map_index] = self.hanet2(x,
                                             dec0_up,
                                             pos,
                                             return_attention=True,
                                             return_posmap=True)
                map_index += 1
            else:
                dec0_up = self.hanet2(x, dec0_up, pos)

        dec0_fine = self.bot_fine(low_level)
        dec0_up = Upsample(dec0_up, low_level.size()[2:])
        dec0 = [dec0_fine, dec0_up]
        dec0 = torch.cat(dec0, 1)
        dec1 = self.final1(dec0)

        if self.args.hanet[3] == 1:
            if attention_map:
                dec1, attention_maps[map_index], pos_maps[
                    map_index] = self.hanet3(dec0,
                                             dec1,
                                             pos,
                                             return_attention=True,
                                             return_posmap=True)
                map_index += 1
            else:
                dec1 = self.hanet3(dec0, dec1, pos)

        dec2 = self.final2(dec1)

        if self.args.hanet[4] == 1:
            if attention_map:
                dec2, attention_maps[map_index], pos_maps[
                    map_index] = self.hanet4(dec1,
                                             dec2,
                                             pos,
                                             return_attention=True,
                                             return_posmap=True)
                map_index += 1
            elif attention_loss:
                dec2, last_attention = self.hanet4(dec1,
                                                   dec2,
                                                   pos,
                                                   return_attention=False,
                                                   return_posmap=False,
                                                   attention_loss=True)
            else:
                dec2 = self.hanet4(dec1, dec2, pos)

        main_out = Upsample(dec2, x_size[2:])

        if self.training:
            loss1 = self.criterion(main_out, gts)

            if self.args.aux_loss is True:
                aux_out = self.dsn(aux_out)
                if aux_gts.dim() == 1:
                    aux_gts = gts
                aux_gts = aux_gts.unsqueeze(1).float()
                aux_gts = nn.functional.interpolate(aux_gts,
                                                    size=aux_out.shape[2:],
                                                    mode='nearest')
                aux_gts = aux_gts.squeeze(1).long()
                loss2 = self.criterion_aux(aux_out, aux_gts)
                if attention_loss:
                    return (loss1, loss2, last_attention)
                else:
                    return (loss1, loss2)
            else:
                if attention_loss:
                    return (loss1, last_attention)
                else:
                    return loss1
        else:
            if attention_map:
                return main_out, attention_maps, pos_maps
            else:
                return main_out
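
A short inference sketch for the HANet-style forward above; model, images, and pos are assumptions standing in for an instance of this network and its inputs:

model.eval()
with torch.no_grad():
    # attention_map=True additionally returns the per-layer attention and
    # position maps collected from the hanet0..hanet4 modules.
    main_out, attention_maps, pos_maps = model(images, pos=pos, attention_map=True)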