def forward(self, inp, gts=None, task=None):
    """Run the trunk and the second decoder head (``aspp2`` / ``final2``).

    Args:
        inp: input batch; ``inp.size()[2:]`` is used as the output size.
        gts: targets handed to ``self.criterion``; only read in training mode.
        task: forwarded unchanged as ``task=`` to ``self.mod6`` and ``self.mod7``.

    Returns:
        ``self.criterion(out, gts)`` when ``self.training`` is true, otherwise
        the result of the final ``Upsample`` call.
    """
    x_size = inp.size()

    # Trunk; the mod2 output is kept as the skip feature for the decoder.
    x = self.mod1(inp)
    m2 = self.mod2(self.pool2(x))
    x = self.mod3(self.pool3(m2))
    x = self.mod4(x)
    x = self.mod5(x)
    x = self.mod6(x, task=task)
    x = self.mod7(x, task=task)
    x = self.aspp2(x)

    # Decoder: bring the ASPP branch to the skip feature's size, then concat.
    dec0_up = self.bot_aspp2(x)
    dec0_fine = self.bot_fine2(m2)
    dec0_up = Upsample(dec0_up, m2.size()[2:])
    dec0 = torch.cat([dec0_fine, dec0_up], 1)

    dec1 = self.final2(dec0)
    out = Upsample(dec1, x_size[2:])

    if self.training:
        # The per-step debug prints of out/gts sizes were removed here.
        return self.criterion(out, gts)
    return out  # [:,:19,:,:]
def forward(self, x, gts=None):
    """Run layer0..layer4, ASPP and the skip-connected decoder.

    The skip feature is the ``layer1`` output when ``self.skip == 'm1'`` and
    the ``layer2`` output otherwise. In training mode the value of
    ``self.criterion(main_out, gts)`` is returned; otherwise ``main_out``.
    """
    input_hw = x.size()[2:]  # 800

    feat0 = self.layer0(x)  # 400
    feat1 = self.layer1(feat0)  # 400
    feat2 = self.layer2(feat1)  # 100
    feat3 = self.layer3(feat2)  # 100
    feat4 = self.layer4(feat3)  # 100

    coarse = self.bot_aspp(self.aspp(feat4))

    # Both branches of the original did the same work on a different feature.
    skip_feat = feat1 if self.skip == 'm1' else feat2
    fine = self.bot_fine(skip_feat)
    coarse = Upsample(coarse, skip_feat.size()[2:])

    logits = self.final(torch.cat([fine, coarse], 1))
    main_out = Upsample(logits, input_hw)

    if not self.training:
        return main_out
    return self.criterion(main_out, gts)
def _fwd(self, x, aspp_lo=None, aspp_attn=None, scale_float=None):
    """Run backbone, ASPP and decoder; also produce the scale-attention maps.

    When ``self.fuse_aspp`` is set and both ``aspp_lo`` and ``aspp_attn`` are
    given, the ASPP output is blended with ``aspp_lo`` using ``aspp_attn`` as
    the weight. ``scale_float`` is accepted but not read here.

    Returns:
        ``(out, logit_attn, aspp_attn, aspp)``.
    """
    input_hw = x.size()[2:]
    s2_features, _, final_features = self.backbone(x)
    aspp = self.aspp(final_features)

    if self.fuse_aspp and not (aspp_lo is None or aspp_attn is None):
        blend_w = scale_as(aspp_attn, aspp)
        lo_feat = scale_as(aspp_lo, aspp)
        aspp = blend_w * lo_feat + (1 - blend_w) * aspp

    coarse = self.bot_aspp(aspp)
    fine = self.bot_fine(s2_features)
    coarse = Upsample(coarse, s2_features.size()[2:])

    # Two separate concatenations, as in the original: one per head.
    cat_s4 = torch.cat([fine, coarse], 1)
    cat_s4_attn = torch.cat([fine, coarse], 1)

    final = self.final(cat_s4)
    attn = self.scale_attn(cat_s4_attn)

    out = Upsample(final, input_hw)
    attn = Upsample(attn, input_hw)

    if not self.attn_2b:
        return out, attn, attn, aspp

    # With attn_2b, channel 0 weights the logits and the rest weight ASPP.
    return out, attn[:, 0:1, :, :], attn[:, 1:, :, :], aspp
def forward_with_smear(self, x, smear_layer, smear_mode, init_spIndx, final_spIndx, psp_assoc, spShape):
    """Forward pass that applies ``spix_pool`` after exactly one named stage.

    ``smear_layer`` selects the stage: one of ``'input'``, ``'mod1'`` ..
    ``'mod7'``, ``'aspp'``, ``'dec0'``, ``'dec1'`` or ``'out'``. Any other
    value leaves every stage untouched. The remaining arguments are handed to
    ``spix_pool`` unchanged.
    """
    def smear_at(stage, feat):
        # Pool only at the requested stage; pass through everywhere else.
        if smear_layer != stage:
            return feat
        return spix_pool(feat, init_spIndx, psp_assoc, final_spIndx, smear_mode, spShape)

    x_size = x.size()
    x = smear_at('input', x)

    x = smear_at('mod1', self.mod1(x))
    m2 = smear_at('mod2', self.mod2(self.pool2(x)))
    x = smear_at('mod3', self.mod3(self.pool3(m2)))

    for stage in ('mod4', 'mod5', 'mod6', 'mod7', 'aspp'):
        x = smear_at(stage, getattr(self, stage)(x))

    dec0_fine = self.bot_fine(m2)
    dec0_up = Upsample(self.bot_aspp(x), m2.size()[2:])
    dec0 = smear_at('dec0', torch.cat([dec0_fine, dec0_up], 1))

    dec1 = smear_at('dec1', self.final(dec0))
    return smear_at('out', Upsample(dec1, x_size[2:]))
def forward(self, inp, gts=None):
    """Run the trunk, ASPP and decoder and return the upsampled output.

    ``gts`` is accepted but never read: the loss branch is disabled, so the
    output is returned in both training and evaluation mode.
    """
    input_hw = inp.size()[2:]

    feat = self.mod1(inp)
    skip = self.mod2(self.pool2(feat))
    feat = self.mod3(self.pool3(skip))
    for stage in (self.mod4, self.mod5, self.mod6, self.mod7, self.aspp):
        feat = stage(feat)

    coarse = self.bot_aspp(feat)
    fine = self.bot_fine(skip)
    coarse = Upsample(coarse, skip.size()[2:])

    logits = self.final(torch.cat([fine, coarse], 1))
    out = Upsample(logits, input_hw)

    # if gts is not None and self.training:
    #     return self.criterion(out, gts)
    return out
def none_spix_gather_smear(pFeat, init_spIndx, spShape):
    """Resize ``pFeat`` to the index map's size and back, without pooling.

    If the last two dimensions of ``pFeat`` and ``init_spIndx`` already match,
    ``pFeat`` is returned untouched. ``spShape`` is accepted but not read.
    Everything runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        idx_h, idx_w = init_spIndx.shape[2:]
        feat_h, feat_w = pFeat.shape[2:]

        if (idx_h, idx_w) != (feat_h, feat_w):
            pFeat = Upsample(pFeat, size=(idx_h, idx_w))  # upsample by interp
            pFeat = Upsample(pFeat, size=(feat_h, feat_w))  # downsample by interp

    return pFeat
def _fwd(self, x):
    """Run backbone and OCR head; return the three upsampled outputs.

    Returns:
        dict with keys ``'cls_out'``, ``'aux_out'`` and ``'logit_attn'``, each
        passed through ``Upsample`` with ``x.size()[2:]``.
    """
    input_hw = x.size()[2:]

    _, _, high_level_features = self.backbone(x)
    cls_out, aux_out, ocr_mid_feats = self.ocr(high_level_features)
    attn = self.scale_attn(ocr_mid_feats)

    # Same upsampling order as before: aux, cls, attn.
    aux_up = Upsample(aux_out, input_hw)
    cls_up = Upsample(cls_out, input_hw)
    attn_up = Upsample(attn, input_hw)

    return {'cls_out': cls_up, 'aux_out': aux_up, 'logit_attn': attn_up}
def hard_spix_gather_smear(pFeat, final_spIndx, spShape):
    """Gather ``pFeat`` per superpixel index and smear it back.

    The superpixel count passed to ``svx.spFeatGather2d`` is
    ``final_spIndx.max() + 1``. If the last two dimensions of ``pFeat`` differ
    from those of ``final_spIndx``, ``pFeat`` is first resized to the index
    map's size and resized back afterwards. ``spShape`` is accepted but not
    read. Everything runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        idx_h, idx_w = final_spIndx.shape[2:]
        feat_h, feat_w = pFeat.shape[2:]
        num_spix = final_spIndx.max().item() + 1

        needs_resize = not (idx_h == feat_h and idx_w == feat_w)

        if needs_resize:
            pFeat = Upsample(pFeat, size=(idx_h, idx_w))  # upsample by interp

        spFeat, _ = svx.spFeatGather2d(pFeat, final_spIndx, num_spix)
        pFeat = svx.spFeatSmear2d(spFeat, final_spIndx)

        if needs_resize:
            pFeat = Upsample(pFeat, size=(feat_h, feat_w))  # downsample by interp

    return pFeat
def forward(self, x, edge):
    """Concatenate pooled-image, edge and per-branch features along dim 1.

    The output order along dim 1 is: image-pooling branch, edge branch, then
    one entry per module in ``self.features`` (each applied to ``x``).
    """
    input_hw = x.size()[2:]

    pooled = self.img_conv(self.img_pooling(x))
    pooled = Upsample(pooled, input_hw)

    edge_features = self.edge_conv(Upsample(edge, input_hw))

    # Collect every branch first and concatenate once; the resulting tensor
    # has the same contents as growing it pairwise.
    branches = [pooled, edge_features]
    for branch in self.features:
        branches.append(branch(x))
    return torch.cat(branches, 1)
def _fwd_feature(self, x):
    """Run backbone, ASPP and decoder without the attention head.

    Returns:
        ``(out, aspp, cat_s4_attn)`` where ``cat_s4_attn`` is a separate
        concatenation of the same two decoder inputs that feed ``self.final``.
    """
    input_hw = x.size()[2:]

    s2_features, _, final_features = self.backbone(x)
    aspp = self.aspp(final_features)

    coarse = self.bot_aspp(aspp)
    fine = self.bot_fine(s2_features)
    coarse = Upsample(coarse, s2_features.size()[2:])

    cat_s4 = torch.cat([fine, coarse], 1)
    cat_s4_attn = torch.cat([fine, coarse], 1)

    out = Upsample(self.final(cat_s4), input_hw)
    return out, aspp, cat_s4_attn
def _fwd(self, x, aspp_lo=None, aspp_attn=None, scale_float=None):
    """Run backbone and ASPP, then the prediction and attention heads.

    ``aspp_lo``, the incoming ``aspp_attn`` and ``scale_float`` are accepted
    but not read. The same upsampled attention map is returned for both the
    logit-attention and ASPP-attention slots.

    Returns:
        ``(out, logit_attn, aspp_attn, aspp)`` where ``aspp`` is the output of
        ``self.bot_aspp``.
    """
    input_hw = x.size()[2:]

    _, _, final_features = self.backbone(x)
    aspp = self.bot_aspp(self.aspp(final_features))

    final = self.final(aspp)
    attn = self.scale_attn(aspp)

    out = Upsample(final, input_hw)
    attn = Upsample(attn, input_hw)

    return out, attn, attn, aspp
def soft_spix_gather_smear(pFeat, init_spIndx, psp_assoc, spShape):
    """Update superpixel features from ``pFeat`` and soft-smear them back.

    ``spShape`` is unpacked as ``(Kh, Kw)`` and both values are forwarded to
    the ``svx`` calls. If the last two dimensions of ``pFeat`` differ from
    those of ``init_spIndx``, ``pFeat`` is resized to the index map's size
    first and resized back afterwards. Everything runs under
    ``torch.no_grad()``.
    """
    with torch.no_grad():
        idx_h, idx_w = init_spIndx.shape[2:]
        feat_h, feat_w = pFeat.shape[2:]
        Kh, Kw = spShape

        needs_resize = not (idx_h == feat_h and idx_w == feat_w)

        if needs_resize:
            pFeat = Upsample(pFeat, size=(idx_h, idx_w))  # upsample by interp

        spFeat, _ = svx.spFeatUpdate2d(pFeat, psp_assoc, init_spIndx, Kh, Kw)
        pFeat = svx.spFeatSoftSmear2d(spFeat, psp_assoc, init_spIndx, Kh, Kw)

        if needs_resize:
            pFeat = Upsample(pFeat, size=(feat_h, feat_w))  # downsample by interp

    return pFeat
def forward(self, inputs):
    """Run the network on ``inputs['images']``.

    In training mode ``inputs['gts']`` is required and the value of
    ``self.criterion(out, gts)`` is returned. Otherwise returns
    ``{'pred': out}``.
    """
    assert 'images' in inputs
    images = inputs['images']
    input_hw = images.size()[2:]

    s2_features, _, final_features = self.backbone(images)
    aspp = self.aspp(final_features)

    coarse = self.bot_aspp(aspp)
    fine = self.bot_fine(s2_features)
    coarse = Upsample(coarse, s2_features.size()[2:])

    final = self.final(torch.cat([fine, coarse], 1))
    out = Upsample(final, input_hw)

    if not self.training:
        return {'pred': out}

    assert 'gts' in inputs
    return self.criterion(out, inputs['gts'])
def _fwd_attn_rev(self, x, cat_s4_attn):
    """Compute the reverse scale-attention maps from ``cat_s4_attn``.

    Returns:
        ``(logit_attn_rev, aspp_attn_rev)``. With ``self.attn_2b`` the first is
        channel 0 and the second the remaining channels of the upsampled map;
        otherwise both are the same upsampled map.
    """
    input_hw = x.size()[2:]
    attn_rev = Upsample(self.scale_attn_rev(cat_s4_attn), input_hw)

    if not self.attn_2b:
        return attn_rev, attn_rev

    return attn_rev[:, 0:1, :, :], attn_rev[:, 1:, :, :]
def _fwd_attn(self, x, cat_s4_attn):
    """Compute the scale-attention maps from ``cat_s4_attn``.

    Returns:
        ``(logit_attn, aspp_attn)``. With ``self.attn_2b`` the first is
        channel 0 and the second the remaining channels of the upsampled map;
        otherwise both are the same upsampled map.
    """
    input_hw = x.size()[2:]
    attn = Upsample(self.scale_attn(cat_s4_attn), input_hw)

    if not self.attn_2b:
        return attn, attn

    return attn[:, 0:1, :, :], attn[:, 1:, :, :]
def _fwd(self, x, aspp_lo=None, aspp_attn=None):
    """
    Run the network, and return final feature and logit predictions.

    When ``self.fuse_aspp`` is set and both ``aspp_lo`` and ``aspp_attn`` are
    given, the ASPP output is blended with ``aspp_lo`` using ``aspp_attn`` as
    the weight. Returns ``(out, cat_s4)``.
    """
    input_hw = x.size()[2:]

    s2_features, _, final_features = self.backbone(x)
    aspp = self.aspp(final_features)

    if self.fuse_aspp and not (aspp_lo is None or aspp_attn is None):
        blend_w = scale_as(aspp_attn, aspp)
        lo_feat = scale_as(aspp_lo, aspp)
        aspp = blend_w * lo_feat + (1 - blend_w) * aspp

    coarse = self.bot_aspp(aspp)
    fine = self.bot_fine(s2_features)
    coarse = Upsample(coarse, s2_features.size()[2:])

    cat_s4 = torch.cat([fine, coarse], 1)
    out = Upsample(self.final(cat_s4), input_hw)

    return out, cat_s4
def forward(self, x, gts=None, smear_layer='', smear_mode='hard', init_spIndx=None, final_spIndx=None, psp_assoc=None, spShape=None):
    """Standard forward pass, or the smearing variant when requested.

    A non-empty ``smear_layer`` delegates everything to
    ``self.forward_with_smear`` and returns its result (no loss is computed on
    that path). Otherwise returns ``self.criterion(out, gts)`` in training
    mode and ``out`` in evaluation mode.
    """
    if smear_layer != '':
        return self.forward_with_smear(
            x, smear_layer, smear_mode, init_spIndx, final_spIndx,
            psp_assoc, spShape)

    input_hw = x.size()[2:]

    feat = self.mod1(x)
    skip = self.mod2(self.pool2(feat))
    feat = self.mod3(self.pool3(skip))
    for stage in (self.mod4, self.mod5, self.mod6, self.mod7, self.aspp):
        feat = stage(feat)

    coarse = self.bot_aspp(feat)
    fine = self.bot_fine(skip)
    coarse = Upsample(coarse, skip.size()[2:])

    logits = self.final(torch.cat([fine, coarse], 1))
    out = Upsample(logits, input_hw)

    if not self.training:
        return out
    return self.criterion(out, gts)
def forward(self, inp, gts=None):
    """Run the shared trunk and two task-specific heads.

    ``self.mod6`` / ``self.mod7`` are called once with ``task='semantic'`` and
    once with ``task='traversability'``. The first result is decoded with
    ``aspp`` / ``bot_aspp`` / ``bot_fine`` / ``final``, the second with the
    ``*2`` counterparts. ``gts`` is accepted but not read.

    Returns:
        ``(out1, out2)``: semantic output, then traversability output.
    """
    input_hw = inp.size()[2:]

    feat = self.mod1(inp)
    skip = self.mod2(self.pool2(feat))
    feat = self.mod3(self.pool3(skip))
    feat = self.mod4(feat)
    feat = self.mod5(feat)

    sem_feat = self.mod7(self.mod6(feat, task='semantic'), task='semantic')
    trav_feat = self.mod7(self.mod6(feat, task='traversability'),
                          task='traversability')

    def decode(task_feat, aspp, bot_aspp, bot_fine, final):
        # Shared decoder recipe; only the modules differ between the heads.
        coarse = bot_aspp(aspp(task_feat))
        fine = bot_fine(skip)
        coarse = Upsample(coarse, skip.size()[2:])
        logits = final(torch.cat([fine, coarse], 1))
        return Upsample(logits, input_hw)

    out1 = decode(sem_feat, self.aspp, self.bot_aspp, self.bot_fine, self.final)
    out2 = decode(trav_feat, self.aspp2, self.bot_aspp2, self.bot_fine2,
                  self.final2)

    return out1, out2
def forward(self, inputs):
    """Run backbone, ASPP and the final head on ``inputs['images']``.

    In training mode ``inputs['gts']`` is required and the value of
    ``self.criterion(out, gts)`` is returned. Otherwise returns
    ``{'pred': out}``.
    """
    assert 'images' in inputs
    images = inputs['images']
    input_hw = images.size()[2:]

    _, _, final_features = self.backbone(images)
    final = self.final(self.aspp(final_features))
    out = Upsample(final, input_hw)

    if not self.training:
        return {'pred': out}

    assert 'gts' in inputs
    return self.criterion(out, inputs['gts'])
def _fwd(self, x, aspp_lo=None, aspp_attn=None, scale_float=None):
    """Like the plain attention ``_fwd``, with ``self.adnb`` on the ASPP branch.

    The bottlenecked ASPP feature is passed through ``self.adnb`` together
    with an all-False boolean mask and a ``None`` positional embedding. The
    result is transposed on its last two dims and viewed back to the
    pre-``adnb`` shape before the usual decoder. ``scale_float`` is accepted
    but not read.

    Returns:
        ``(out, logit_attn, aspp_attn, aspp)``.
    """
    input_hw = x.size()[2:]

    s2_features, _, final_features = self.backbone(x)
    aspp = self.aspp(final_features)

    if self.fuse_aspp and not (aspp_lo is None or aspp_attn is None):
        blend_w = scale_as(aspp_attn, aspp)
        lo_feat = scale_as(aspp_lo, aspp)
        aspp = blend_w * lo_feat + (1 - blend_w) * aspp

    coarse = self.bot_aspp(aspp)
    conv_s2 = self.bot_fine(s2_features)

    # spatial attention here.
    # conv_aspp_ = self.asnb(conv_s2, conv_aspp_)
    # NOTE(review): this resizes the tensor to its own spatial size, which
    # looks like a leftover; kept as-is to preserve behaviour. Confirm.
    coarse = Upsample(coarse, coarse.size()[2:])
    coarse_shape = coarse.shape

    zero_mask = coarse.new_zeros(
        (coarse.shape[0], coarse.shape[2], coarse.shape[3]),
        dtype=torch.bool)
    coarse = self.adnb([coarse], masks=[zero_mask], pos_embeds=[None])
    coarse = coarse.transpose(-1, -2).view(coarse_shape)

    conv_aspp = Upsample(coarse, s2_features.size()[2:])

    cat_s4 = torch.cat([conv_s2, conv_aspp], 1)
    cat_s4_attn = torch.cat([conv_s2, conv_aspp], 1)

    final = self.final(cat_s4)
    attn = self.scale_attn(cat_s4_attn)

    out = Upsample(final, input_hw)
    attn = Upsample(attn, input_hw)

    if not self.attn_2b:
        return out, attn, attn, aspp

    return out, attn[:, 0:1, :, :], attn[:, 1:, :, :], aspp
def forward(self, inputs):
    """Run backbone, ASPP, bottleneck and final head on ``inputs['images']``.

    In training mode ``inputs['gts']`` is required and the value of
    ``self.criterion(pred, gts)`` is returned. Otherwise returns
    ``{'pred': pred}``.
    """
    images = inputs['images']
    input_hw = images.size()[2:]

    _, _, final_features = self.backbone(images)
    bottleneck = self.bot_aspp(self.aspp(final_features))
    pred = Upsample(self.final(bottleneck), input_hw)

    if not self.training:
        return {'pred': pred}

    assert 'gts' in inputs
    return self.criterion(pred, inputs['gts'])
def forward(self, inp_img, audio1, audio6, gts=None, gts_diff_2=None, gts_diff_5=None, gts_depth=None):
    """Predict a segmentation map and a depth map from two audio inputs.

    Args:
        inp_img: accepted for interface compatibility; not read here.
        audio1, audio6: inputs to ``self.forward_Seg`` and ``self.forward_SASR``.
        gts: target for ``self.criterion`` (training only).
        gts_diff_2, gts_diff_5: targets for the two ``forward_SASR`` outputs
            (training only).
        gts_depth: target for the depth output (training only).

    Returns:
        In training mode, a scalar loss:
        ``10 * criterion(out, gts) + 0.5 * MSE(outd, gts_depth)``, plus the
        auxiliary SASR loss when that loss is below 5.0.
        Otherwise the tuple ``(out, outd)``.
    """
    out_aud1 = self.forward_Seg(audio1)
    out_aud6 = self.forward_Seg(audio6)

    mask_prediction, mask_prediction2 = self.forward_SASR(audio1, audio6)

    # Per-audio bottleneck (same module for both), then fuse along dim 1.
    dec0_aud1 = self.bot_aud1(Upsample(out_aud1, [60, 120]))
    dec0_aud6 = self.bot_aud1(Upsample(out_aud6, [60, 120]))
    dec0_aud = self.bot_multiaud(torch.cat([dec0_aud1, dec0_aud6], 1))

    # Two parallel heads on the fused feature: segmentation and depth.
    dec0_auds = self.aspp(dec0_aud)
    dec0_audd = self.depthaspp(dec0_aud)
    dec0_up = self.bot_aspp(dec0_auds)
    dec0_upd = self.bot_depthaspp(dec0_audd)

    dec0_up = Upsample(dec0_up, [240, 480])
    dec0_upd = Upsample(dec0_upd, [160, 512])

    dec1 = self.final(dec0_up)
    dec1d = self.final_depth(dec0_upd)

    out = Upsample(dec1, [480, 960])
    outd = Upsample(dec1d, [320, 1024])

    if not self.training:
        return out, outd

    # The auxiliary loss needs gts_diff_2 / gts_diff_5, which default to None.
    # Computing it only in training keeps inference without targets working.
    loss = (self.MSEcriterion(mask_prediction, gts_diff_2)
            + self.MSEcriterion(mask_prediction2, gts_diff_5))

    if loss < 5.0:
        return 10 * self.criterion(out, gts) + loss + 0.5 * self.MSEcriterion(outd, gts_depth)
    return 10 * self.criterion(out, gts) + 0.5 * self.MSEcriterion(outd, gts_depth)
def forward(self, x, gts=None, aux_gts=None, img_gt=None, visualize=False, cal_covstat=False, apply_wtloss=True):
    """Forward pass that also collects the feature maps in ``w_arr``.

    NOTE(review): block nesting below was reconstructed from a
    whitespace-collapsed source; confirm against the original file.

    Args:
        x: input batch; when ``cal_covstat`` is true it is first joined with
            ``torch.cat(x, dim=0)``.
        gts: targets for ``self.criterion`` (training only).
        aux_gts: targets for ``self.criterion_aux``; replaced by ``gts`` when
            ``aux_gts.dim() == 1`` (training only).
        img_gt: accepted but not read here.
        visualize: also return covariance matrices of the ``w_arr`` entries.
        cal_covstat: only update the statistics held by
            ``self.cov_matrix_layer`` and return ``0``.
        apply_wtloss: when false, the whitening loss stays at its zero start
            value.

    Returns:
        * ``0`` when ``cal_covstat`` is true.
        * In training mode, a list ``[loss1, loss2]``, with ``wt_loss``
          appended when ``self.args.use_wtloss`` is set, and the list of
          covariance matrices appended when ``visualize`` is also set.
        * Otherwise ``main_out``, or ``(main_out, f_cor_arr)`` when
          ``visualize`` is true.
    """
    # Accumulates the second output of the layers selected by args.wt_layer.
    w_arr = []

    if cal_covstat:
        x = torch.cat(x, dim=0)

    x_size = x.size()  # 800

    if self.trunk == 'mobilenetv2' or self.trunk == 'shufflenetv2':
        # These trunks take and return the [features, w_arr] pair directly.
        x_tuple = self.layer0([x, w_arr])
        x = x_tuple[0]
        w_arr = x_tuple[1]
    else:  # ResNet
        # layer0 is stepped through by index; where args.wt_layer is 1 or 2
        # the indexed sub-layer returns (x, w) instead of just x.
        if self.three_input_layer:
            x = self.layer0[0](x)
            if self.args.wt_layer[0] == 1 or self.args.wt_layer[0] == 2:
                x, w = self.layer0[1](x)
                w_arr.append(w)
            else:
                x = self.layer0[1](x)
            x = self.layer0[2](x)
            x = self.layer0[3](x)
            if self.args.wt_layer[1] == 1 or self.args.wt_layer[1] == 2:
                x, w = self.layer0[4](x)
                w_arr.append(w)
            else:
                x = self.layer0[4](x)
            x = self.layer0[5](x)
            x = self.layer0[6](x)
            if self.args.wt_layer[2] == 1 or self.args.wt_layer[2] == 2:
                x, w = self.layer0[7](x)
                w_arr.append(w)
            else:
                x = self.layer0[7](x)
            x = self.layer0[8](x)
            x = self.layer0[9](x)
        else:  # Single Input Layer
            x = self.layer0[0](x)
            if self.args.wt_layer[2] == 1 or self.args.wt_layer[2] == 2:
                x, w = self.layer0[1](x)
                w_arr.append(w)
            else:
                x = self.layer0[1](x)
            x = self.layer0[2](x)
            x = self.layer0[3](x)

    # layer1..layer4 pass the [features, w_arr] pair along.
    x_tuple = self.layer1([x, w_arr])  # 400
    low_level = x_tuple[0]

    x_tuple = self.layer2(x_tuple)  # 100
    x_tuple = self.layer3(x_tuple)  # 100
    aux_out = x_tuple[0]
    x_tuple = self.layer4(x_tuple)  # 100
    x = x_tuple[0]
    w_arr = x_tuple[1]

    if cal_covstat:
        # Statistics-only pass: no decoder, no loss.
        for index, f_map in enumerate(w_arr):
            # Instance Whitening
            B, C, H, W = f_map.shape  # i-th feature size (B X C X H X W)
            HW = H * W
            f_map = f_map.contiguous().view(B, C, -1)  # B X C X H X W > B X C X (H X W)
            eye, reverse_eye = self.cov_matrix_layer[index].get_eye_matrix()
            f_cor = torch.bmm(f_map, f_map.transpose(1, 2)).div(HW - 1) + (self.eps * eye)  # B X C X C / HW
            off_diag_elements = f_cor * reverse_eye
            #print("here", off_diag_elements.shape)
            # Variance is taken over dim 0 of the masked covariance.
            self.cov_matrix_layer[index].set_variance_of_covariance(torch.var(off_diag_elements, dim=0))
        return 0

    x = self.aspp(x)
    dec0_up = self.bot_aspp(x)

    dec0_fine = self.bot_fine(low_level)
    dec0_up = Upsample(dec0_up, low_level.size()[2:])
    dec0 = [dec0_fine, dec0_up]
    dec0 = torch.cat(dec0, 1)
    dec1 = self.final1(dec0)
    dec2 = self.final2(dec1)
    main_out = Upsample(dec2, x_size[2:])

    if self.training:
        loss1 = self.criterion(main_out, gts)

        if self.args.use_wtloss:
            # NOTE(review): the accumulator is created with .cuda(), so this
            # path presumably requires a CUDA device; confirm.
            wt_loss = torch.FloatTensor([0]).cuda()
            if apply_wtloss:
                for index, f_map in enumerate(w_arr):
                    eye, mask_matrix, margin, num_remove_cov = self.cov_matrix_layer[index].get_mask_matrix()
                    loss = instance_whitening_loss(f_map, eye, mask_matrix, margin, num_remove_cov)
                    wt_loss = wt_loss + loss
            # NOTE(review): divides by len(w_arr) even when it is empty.
            wt_loss = wt_loss / len(w_arr)

        aux_out = self.dsn(aux_out)
        if aux_gts.dim() == 1:
            aux_gts = gts
        # Resize the auxiliary targets to the auxiliary output's size with
        # nearest-neighbour interpolation, then cast back to long.
        aux_gts = aux_gts.unsqueeze(1).float()
        aux_gts = nn.functional.interpolate(aux_gts, size=aux_out.shape[2:], mode='nearest')
        aux_gts = aux_gts.squeeze(1).long()
        loss2 = self.criterion_aux(aux_out, aux_gts)

        return_loss = [loss1, loss2]
        if self.args.use_wtloss:
            return_loss.append(wt_loss)

        if self.args.use_wtloss and visualize:
            f_cor_arr = []
            for f_map in w_arr:
                f_cor, _ = get_covariance_matrix(f_map)
                f_cor_arr.append(f_cor)
            return_loss.append(f_cor_arr)

        return return_loss
    else:
        if visualize:
            f_cor_arr = []
            for f_map in w_arr:
                f_cor, _ = get_covariance_matrix(f_map)
                f_cor_arr.append(f_cor)
            return main_out, f_cor_arr
        else:
            return main_out
def forward(self, x, gts=None, aux_gts=None, pos=None, attention_map=False, attention_loss=False):
    """Forward pass with optional ``hanet0`` .. ``hanet4`` attention stages.

    NOTE(review): block nesting below was reconstructed from a
    whitespace-collapsed source; confirm against the original file.

    Each stage ``i`` runs only when ``self.args.hanet[i] == 1``.

    Args:
        x: input batch.
        gts: targets for ``self.criterion`` (training only).
        aux_gts: targets for ``self.criterion_aux``; replaced by ``gts`` when
            ``aux_gts.dim() == 1`` (training with ``args.aux_loss`` only).
        pos: forwarded to every ``hanet*`` call.
        attention_map: ask each enabled stage for its attention and position
            maps and return them in evaluation mode.
        attention_loss: ask ``hanet4`` for ``last_attention`` and add it to the
            training return value.

    Returns:
        * Training: ``loss1``, or a tuple of ``loss1`` with ``loss2`` (when
          ``self.args.aux_loss is True``) and/or ``last_attention`` (when
          ``attention_loss`` is true).
        * Evaluation: ``main_out``, or
          ``(main_out, attention_maps, pos_maps)`` when ``attention_map`` is
          true.
    """
    x_size = x.size()  # 800
    x = self.layer0(x)  # 400
    x = self.layer1(x)  # 400
    low_level = x
    x = self.layer2(x)  # 100
    x = self.layer3(x)  # 100
    aux_out = x
    x = self.layer4(x)  # 100

    if self.num_attention_layer > 0:
        if attention_map:
            # One placeholder slot per attention layer; map_index is the next
            # slot to fill.
            attention_maps = [
                torch.Tensor() for i in range(self.num_attention_layer)
            ]
            pos_maps = [
                torch.Tensor() for i in range(self.num_attention_layer)
            ]
            map_index = 0

    # NOTE(review): attention_maps / pos_maps / map_index are only bound when
    # num_attention_layer > 0 and attention_map is true, but the stages below
    # use them whenever attention_map is true; confirm callers keep the flags
    # consistent.
    if self.args.hanet[0] == 1:
        if attention_map:
            x, attention_maps[map_index], pos_maps[
                map_index] = self.hanet0(aux_out,
                                         x,
                                         pos,
                                         return_attention=True,
                                         return_posmap=True)
            map_index += 1
        else:
            x = self.hanet0(aux_out, x, pos)

    # Kept as the first argument of the post-ASPP stage.
    represent = x

    x = self.aspp(x)

    if self.args.hanet[1] == 1:
        if attention_map:
            x, attention_maps[map_index], pos_maps[
                map_index] = self.hanet1(represent,
                                         x,
                                         pos,
                                         return_attention=True,
                                         return_posmap=True)
            map_index += 1
        else:
            x = self.hanet1(represent, x, pos)

    dec0_up = self.bot_aspp(x)

    if self.args.hanet[2] == 1:
        if attention_map:
            dec0_up, attention_maps[map_index], pos_maps[
                map_index] = self.hanet2(x,
                                         dec0_up,
                                         pos,
                                         return_attention=True,
                                         return_posmap=True)
            map_index += 1
        else:
            dec0_up = self.hanet2(x, dec0_up, pos)

    dec0_fine = self.bot_fine(low_level)
    dec0_up = Upsample(dec0_up, low_level.size()[2:])
    dec0 = [dec0_fine, dec0_up]
    dec0 = torch.cat(dec0, 1)
    dec1 = self.final1(dec0)

    if self.args.hanet[3] == 1:
        if attention_map:
            dec1, attention_maps[map_index], pos_maps[
                map_index] = self.hanet3(dec0,
                                         dec1,
                                         pos,
                                         return_attention=True,
                                         return_posmap=True)
            map_index += 1
        else:
            dec1 = self.hanet3(dec0, dec1, pos)

    dec2 = self.final2(dec1)

    if self.args.hanet[4] == 1:
        if attention_map:
            dec2, attention_maps[map_index], pos_maps[
                map_index] = self.hanet4(dec1,
                                         dec2,
                                         pos,
                                         return_attention=True,
                                         return_posmap=True)
            map_index += 1
        elif attention_loss:
            # Only this branch binds last_attention.
            dec2, last_attention = self.hanet4(dec1,
                                               dec2,
                                               pos,
                                               return_attention=False,
                                               return_posmap=False,
                                               attention_loss=True)
        else:
            dec2 = self.hanet4(dec1, dec2, pos)

    main_out = Upsample(dec2, x_size[2:])

    if self.training:
        loss1 = self.criterion(main_out, gts)

        if self.args.aux_loss is True:
            aux_out = self.dsn(aux_out)
            if aux_gts.dim() == 1:
                aux_gts = gts
            # Resize the auxiliary targets to the auxiliary output's size
            # with nearest-neighbour interpolation, then cast back to long.
            aux_gts = aux_gts.unsqueeze(1).float()
            aux_gts = nn.functional.interpolate(aux_gts,
                                                size=aux_out.shape[2:],
                                                mode='nearest')
            aux_gts = aux_gts.squeeze(1).long()
            loss2 = self.criterion_aux(aux_out, aux_gts)
            # NOTE(review): last_attention is unbound here unless hanet[4] is
            # enabled and attention_map is false; confirm.
            if attention_loss:
                return (loss1, loss2, last_attention)
            else:
                return (loss1, loss2)
        else:
            if attention_loss:
                return (loss1, last_attention)
            else:
                return loss1
    else:
        if attention_map:
            return main_out, attention_maps, pos_maps
        else:
            return main_out