def _init_aspp(self):
    """Attach an ASPP head built on the backbone's output channels.

    The head is stored as ``self.aspp``. Each of its ``nn.Conv2d`` and
    ``self.NormLayer`` submodules is added to ``self.from_scratch_layers``.
    The running statistics of the head's normalisation layers are then
    frozen via ``self._fix_running_stats``.
    """
    head = ASPP(self.fan_out(), 8, self.NormLayer)
    self.aspp = head

    # collect the freshly created conv / norm layers of the head
    fresh = [layer for layer in head.modules()
             if isinstance(layer, (nn.Conv2d, self.NormLayer))]
    self.from_scratch_layers.extend(fresh)

    # freeze BN statistics of the ASPP head
    self._fix_running_stats(head)
class SoftMaxAE(backbone):
    """Segmentation network that predicts softmax masks on top of ``backbone``.

    Pipeline, as implemented in ``forward``:

    1. Backbone features go through an ASPP head.
    2. They are fused with the backbone's ``conv3`` features (skip
       connection plus ``shallow_mask``) and merged by a stochastic gate.
    3. The result is decoded into ``num_classes - 1`` foreground logits.
    4. A constant background logit is prepended and a softmax over the
       class dimension yields the masks.
    5. Image-level class scores are derived from the masks.

    Outside test mode, the masks are also refined with PAMR into pseudo
    ground truth for a mask loss.

    Args:
        config: configuration object. This class reads ``PAMR_ITER``,
            ``PAMR_KERNEL``, ``SG_PSI``, ``FOCAL_P`` and ``FOCAL_LAMBDA``.
        pre_weights: forwarded unchanged to ``self._init_weights``.
        num_classes: number of output mask channels, including the
            prepended background channel.
        dropout: ignored by this class; kept for interface compatibility.
    """

    def __init__(self, config, pre_weights=None, num_classes=21, dropout=True):
        super().__init__()

        self.cfg = config
        self.num_classes = num_classes

        # Initialise the backbone weights, then freeze the normalisation
        # layers that exist at this point. Only the backbone exists yet,
        # since the ASPP head and the decoder are created below.
        self._init_weights(pre_weights)
        self._fix_running_stats(self, fix_params=True)  # freeze backbone BNs

        # Decoder
        self._init_aspp()
        self._init_decoder(num_classes)

        # Populated by forward_backbone() / forward(). Note that they keep
        # references to the latest tensors between calls.
        self._backbone = None
        self._mask_logits = None

    def _init_aspp(self):
        """Create ``self.aspp`` on top of the backbone output channels.

        Its conv and normalisation submodules are appended to
        ``self.from_scratch_layers``, and the running statistics of its
        normalisation layers are frozen.
        """
        self.aspp = ASPP(self.fan_out(), 8, self.NormLayer)

        for m in self.aspp.modules():
            if isinstance(m, (nn.Conv2d, self.NormLayer)):
                self.from_scratch_layers.append(m)

        self._fix_running_stats(self.aspp)  # freeze BN

    def _init_decoder(self, num_classes):
        """Build the PAMR refiner and the mask decoder layers.

        All conv / norm layers created through the local helpers, plus
        ``self.shallow_mask.from_scratch_layers``, are registered in
        ``self.from_scratch_layers``. The last 1x1 convolution outputs
        ``num_classes - 1`` channels; ``forward`` prepends the background.
        """
        self._aff = PAMR(self.cfg.PAMR_ITER, self.cfg.PAMR_KERNEL)

        def conv2d(*args, **kwargs):
            # Conv layer with Kaiming-normal weights, tracked as from-scratch.
            conv = nn.Conv2d(*args, **kwargs)
            self.from_scratch_layers.append(conv)
            torch.nn.init.kaiming_normal_(conv.weight)
            return conv

        def bnorm(*args, **kwargs):
            # Normalisation layer tracked as from-scratch. When it has affine
            # parameters, start from identity (scale 1, shift 0).
            bn = self.NormLayer(*args, **kwargs)
            self.from_scratch_layers.append(bn)
            if bn.weight is not None:
                bn.weight.data.fill_(1)
                bn.bias.data.zero_()
            return bn

        # pre-processing for shallow features
        self.shallow_mask = GCI(self.NormLayer)
        self.from_scratch_layers += self.shallow_mask.from_scratch_layers

        # Stochastic Gate
        self.sg = StochasticGate()

        # Skip branch: 1x1 projection of 256-channel conv3 features to 48 ch.
        self.fc8_skip = nn.Sequential(conv2d(256, 48, 1, bias=False),
                                      bnorm(48),
                                      nn.ReLU())
        # Fusion of the upsampled deep features with the skip branch. The
        # input width is 304 = 256 + 48, so the deep features must carry
        # 256 channels.
        self.fc8_x = nn.Sequential(conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                   bnorm(256),
                                   nn.ReLU())

        # decoder
        self.last_conv = nn.Sequential(conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       bnorm(256),
                                       nn.ReLU(),
                                       nn.Dropout(0.5),
                                       conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       bnorm(256),
                                       nn.ReLU(),
                                       nn.Dropout(0.1),
                                       conv2d(256, num_classes - 1, kernel_size=1, stride=1))

    def run_pamr(self, im, mask):
        """Refine ``mask`` with PAMR after resizing ``im`` to the mask size."""
        im = F.interpolate(im, mask.size()[-2:], mode="bilinear", align_corners=True)
        masks_dec = self._aff(im, mask)
        return masks_dec

    def forward_backbone(self, x):
        """Run the backbone, cache its feature dict, and return ``'conv6'``."""
        self._backbone = super().forward_as_dict(x)
        return self._backbone['conv6']

    def forward(self, y, y_raw=None, labels=None):
        """Predict class scores and segmentation masks.

        Args:
            y: input batch fed to the backbone. Output masks are upscaled to
                its spatial size (last two dimensions).
            y_raw: image handed to PAMR after resizing to mask resolution.
                Presumably the un-normalised image; confirm with the caller.
            labels: foreground class indicators, broadcast as
                ``labels[:, :, None, None]`` against the foreground mask
                channels, i.e. shape ``(B, num_classes - 1)``. Presumably
                multi-hot 0/1 values.

        Returns:
            Test mode (both ``y_raw`` and ``labels`` are ``None``):
            ``(cls, masks)``. ``masks`` holds the softmax masks rescaled to
            ``y``, without label-based cleaning.

            Otherwise: ``(cls, cls_fg, {"cam": masks, "dec": masks_dec},
            mask_logits, pseudo_gt, loss_mask)``.
        """
        test_mode = y_raw is None and labels is None

        # 1. backbone pass (also caches features in self._backbone)
        x = self.forward_backbone(y)

        # 2. ASPP modules
        x = self.aspp(x)

        #
        # 3. merging deep and shallow features
        #

        # 3.1 skip connection for deep features
        x2_x = self.fc8_skip(self._backbone['conv3'])
        x_up = rescale_as(x, x2_x)
        x = self.fc8_x(torch.cat([x_up, x2_x], 1))

        # 3.2 deep feature context for shallow features
        x2 = self.shallow_mask(self._backbone['conv3'], x)

        # 3.3 stochastically merging the masks
        x = self.sg(x, x2, alpha_rate=self.cfg.SG_PSI)

        # 4. final convs to get the foreground mask logits
        x = self.last_conv(x)

        #
        # 5. Finalising the masks and scores
        #

        # constant background logit of 1.0, prepended as channel 0
        bg = torch.ones_like(x[:, :1])
        x = torch.cat([bg, x], 1)

        bs, c, h, w = x.size()

        masks = F.softmax(x, dim=1)

        # flatten the spatial dimensions: (B, C, H*W)
        features = x.view(bs, c, -1)
        masks_ = masks.view(bs, c, -1)

        # classification score: mask-weighted sum of logits, normalised by
        # (1 + mask area)
        cls_1 = (features * masks_).sum(-1) / (1.0 + masks_.sum(-1))

        # focal penalty computed from the mean mask value of each class
        cls_2 = focal_loss(masks_.mean(-1),
                           p=self.cfg.FOCAL_P,
                           c=self.cfg.FOCAL_LAMBDA)

        # adding the two terms together; the background channel is dropped
        cls = cls_1[:, 1:] + cls_2[:, 1:]

        if test_mode:
            # in test mode no label-based mask cleaning is performed
            return cls, rescale_as(masks, y)

        # only refreshed outside test mode
        self._mask_logits = x

        # Foreground stats: mean mask value averaged over the labelled
        # classes. Clamping the label count avoids 0/0 = NaN for samples
        # without any foreground label. For 0/1 labels the count is
        # otherwise >= 1, so those values are unchanged.
        masks_ = masks_[:, 1:]
        cls_fg = (masks_.mean(-1) * labels).sum(-1) / labels.sum(-1).clamp(min=1)

        # mask refinement with PAMR (no gradient through the refinement input)
        masks_dec = self.run_pamr(y_raw, masks.detach())

        # upscale the masks & clean
        masks = self._rescale_and_clean(masks, y, labels)
        masks_dec = self._rescale_and_clean(masks_dec, y, labels)

        # create pseudo GT
        pseudo_gt = pseudo_gtmask(masks_dec).detach()
        loss_mask = balanced_mask_loss_ce(self._mask_logits, pseudo_gt, labels)

        return cls, cls_fg, {"cam": masks, "dec": masks_dec}, self._mask_logits, pseudo_gt, loss_mask

    def _rescale_and_clean(self, masks, image, labels):
        """Rescale to fit the image size and remove any masks of labels that are not present.

        ``masks`` is bilinearly resized to ``image.size()[-2:]``. Every
        foreground channel (index >= 1) is then multiplied in place by the
        matching entry of ``labels``; channel 0 (background) is left
        untouched.
        """
        masks = F.interpolate(masks, size=image.size()[-2:], mode='bilinear', align_corners=True)
        masks[:, 1:] *= labels[:, :, None, None].type_as(masks)
        return masks