Example #1
0
        def _init_aspp(self):
            """Create the ASPP head and register its layers as from-scratch.

            The head is built from ``self.fan_out()``, the constant ``8`` and
            ``self.NormLayer``. Every ``nn.Conv2d`` / ``self.NormLayer``
            sub-module of the head is appended to
            ``self.from_scratch_layers``, after which the head is handed to
            ``self._fix_running_stats``.
            """
            aspp = ASPP(self.fan_out(), 8, self.NormLayer)
            self.aspp = aspp

            scratch_types = (nn.Conv2d, self.NormLayer)
            self.from_scratch_layers.extend(
                layer for layer in aspp.modules()
                if isinstance(layer, scratch_types)
            )

            # freeze BN
            self._fix_running_stats(aspp)
Example #2
0
    class SoftMaxAE(backbone):
        """Segmentation network: ``backbone`` features + ASPP + a mask decoder.

        The forward pass produces ``num_classes - 1`` foreground score maps,
        prepends a constant background channel and converts the result into
        softmax masks. Classification scores are derived from those masks; in
        training mode the masks are additionally refined with PAMR and turned
        into a pseudo ground truth that feeds a mask loss.

        Layers created by this class are appended to
        ``self.from_scratch_layers`` (the list itself is expected to be
        provided by ``backbone``).
        """

        def __init__(self, config, pre_weights=None, num_classes=21, dropout=True):
            """Build the backbone, the ASPP head and the decoder.

            Args:
                config: configuration object; this class reads ``PAMR_ITER``,
                    ``PAMR_KERNEL``, ``SG_PSI``, ``FOCAL_P`` and
                    ``FOCAL_LAMBDA`` from it.
                pre_weights: forwarded to ``self._init_weights`` to initialise
                    the backbone.
                num_classes: number of classes including background; the
                    decoder predicts ``num_classes - 1`` foreground maps.
                dropout: accepted for interface compatibility but not read
                    anywhere in this class (the decoder always contains its
                    dropout layers).
            """
            super().__init__()

            self.cfg = config
            self.num_classes = num_classes

            self._init_weights(pre_weights)  # initialise backbone weights
            # Called before the heads below exist, so only the modules present
            # at this point (the backbone) are affected -- freeze backbone BNs.
            self._fix_running_stats(self, fix_params=True)

            # Decoder
            self._init_aspp()
            self._init_decoder(num_classes)

            # Caches filled during forward(): the backbone feature mapping and
            # the mask logits of the last training-mode pass.
            self._backbone = None
            self._mask_logits = None

        def _init_aspp(self):
            """Create the ASPP head and register its layers as from-scratch.

            Every ``nn.Conv2d`` / ``self.NormLayer`` sub-module of the head is
            appended to ``self.from_scratch_layers``; the head is then handed
            to ``self._fix_running_stats``.
            """
            self.aspp = ASPP(self.fan_out(), 8, self.NormLayer)

            for m in self.aspp.modules():
                if isinstance(m, (nn.Conv2d, self.NormLayer)):
                    self.from_scratch_layers.append(m)

            self._fix_running_stats(self.aspp)  # freeze BN

        def _init_decoder(self, num_classes):
            """Create PAMR, the shallow-feature branch, the gate and the final convs.

            Every conv / norm layer built through the local ``conv2d`` and
            ``bnorm`` helpers is also appended to ``self.from_scratch_layers``.
            The statement order is significant: it fixes the module
            registration order, the order of ``from_scratch_layers`` and the
            order in which random initialisation is drawn.

            Args:
                num_classes: number of classes including background; the last
                    conv outputs ``num_classes - 1`` channels.
            """

            self._aff = PAMR(self.cfg.PAMR_ITER, self.cfg.PAMR_KERNEL)

            def conv2d(*args, **kwargs):
                # nn.Conv2d with Kaiming-normal weights, tracked as from-scratch
                conv = nn.Conv2d(*args, **kwargs)
                self.from_scratch_layers.append(conv)
                torch.nn.init.kaiming_normal_(conv.weight)
                return conv

            def bnorm(*args, **kwargs):
                # norm layer reset to weight 1 / bias 0, tracked as from-scratch
                bn = self.NormLayer(*args, **kwargs)
                self.from_scratch_layers.append(bn)
                # weight may be None (checked here); nothing to reset then
                if bn.weight is not None:
                    bn.weight.data.fill_(1)
                    bn.bias.data.zero_()
                return bn

            # pre-processing for shallow features
            self.shallow_mask = GCI(self.NormLayer)
            self.from_scratch_layers += self.shallow_mask.from_scratch_layers

            # Stochastic Gate
            self.sg = StochasticGate()
            self.fc8_skip = nn.Sequential(conv2d(256, 48, 1, bias=False), bnorm(48), nn.ReLU())
            # 304 = 256 + 48: fc8_x consumes the concatenation built in
            # forward() (assumes the ASPP output has 256 channels -- TODO confirm)
            self.fc8_x = nn.Sequential(conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       bnorm(256), nn.ReLU())

            # decoder
            self.last_conv = nn.Sequential(conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                           bnorm(256), nn.ReLU(),
                                           nn.Dropout(0.5),
                                           conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                           bnorm(256), nn.ReLU(),
                                           nn.Dropout(0.1),
                                           conv2d(256, num_classes - 1, kernel_size=1, stride=1))

        def run_pamr(self, im, mask):
            """Refine ``mask`` with PAMR guided by the image ``im``.

            ``im`` is first resized (bilinear, ``align_corners=True``) to the
            spatial size of ``mask``; the pair is then passed to ``self._aff``.

            Returns:
                Whatever ``self._aff(im, mask)`` returns (the refined masks).
            """
            im = F.interpolate(im, mask.size()[-2:], mode="bilinear", align_corners=True)
            masks_dec = self._aff(im, mask)
            return masks_dec

        def forward_backbone(self, x):
            """Run the backbone and cache its feature mapping.

            The mapping returned by the parent's ``forward_as_dict`` is stored
            in ``self._backbone`` (``forward`` later reads its ``'conv3'``
            entry); the ``'conv6'`` entry is returned.
            """
            self._backbone = super().forward_as_dict(x)
            return self._backbone['conv6']

        def forward(self, y, y_raw=None, labels=None):
            """Compute classification scores and segmentation masks.

            Args:
                y: input batch fed to the backbone.
                y_raw: image batch passed to PAMR for mask refinement
                    (training mode only).
                labels: per-image indicators for the foreground classes
                    (training mode only); indexed as a 2-D tensor and
                    broadcast over the foreground mask channels.

            Returns:
                Test mode (``y_raw`` and ``labels`` both ``None``): the tuple
                ``(cls, rescale_as(masks, y))``.

                Training mode: the tuple ``(cls, cls_fg, {"cam": masks,
                "dec": masks_dec}, mask_logits, pseudo_gt, loss_mask)``.

            Raises:
                ValueError: if exactly one of ``y_raw`` / ``labels`` is given.
            """
            test_mode = y_raw is None and labels is None

            # The training path needs both inputs; fail early with a clear
            # message instead of an opaque error deep inside the pass.
            if not test_mode and (y_raw is None or labels is None):
                raise ValueError(
                    "y_raw and labels must either both be given (training) "
                    "or both be omitted (test mode)")

            # 1. backbone pass
            x = self.forward_backbone(y)

            # 2. ASPP modules
            x = self.aspp(x)

            #
            # 3. merging deep and shallow features
            #

            # 3.1 skip connection for deep features
            x2_x = self.fc8_skip(self._backbone['conv3'])
            x_up = rescale_as(x, x2_x)
            x = self.fc8_x(torch.cat([x_up, x2_x], 1))

            # 3.2 deep feature context for shallow features
            x2 = self.shallow_mask(self._backbone['conv3'], x)

            # 3.3 stochastically merging the masks
            x = self.sg(x, x2, alpha_rate=self.cfg.SG_PSI)

            # 4. final convs to get the masks
            x = self.last_conv(x)

            #
            # 5. Finalising the masks and scores
            #

            # constant BG scores: a channel of ones prepended as class 0
            bg = torch.ones_like(x[:, :1])
            x = torch.cat([bg, x], 1)

            bs, c = x.size()[:2]

            masks = F.softmax(x, dim=1)

            # reshaping: flatten the spatial dimensions
            features = x.view(bs, c, -1)
            masks_ = masks.view(bs, c, -1)

            # classification loss: mask-weighted sum of the scores, normalised
            # by (1 + total mask mass) per class
            cls_1 = (features * masks_).sum(-1) / (1.0 + masks_.sum(-1))

            # focal penalty loss
            cls_2 = focal_loss(masks_.mean(-1),
                               p=self.cfg.FOCAL_P,
                               c=self.cfg.FOCAL_LAMBDA)

            # adding the losses together (background channel dropped)
            cls = cls_1[:, 1:] + cls_2[:, 1:]

            if test_mode:
                # in test mode no mask cleaning is performed
                return cls, rescale_as(masks, y)

            self._mask_logits = x

            # foreground stats
            masks_ = masks_[:, 1:]
            # NOTE(review): divides by labels.sum(-1); a sample whose labels
            # sum to zero yields NaN here -- confirm callers never pass one.
            cls_fg = (masks_.mean(-1) * labels).sum(-1) / labels.sum(-1)

            # mask refinement with PAMR (no gradient flows into the masks)
            masks_dec = self.run_pamr(y_raw, masks.detach())

            # upscale the masks & clean
            masks = self._rescale_and_clean(masks, y, labels)
            masks_dec = self._rescale_and_clean(masks_dec, y, labels)

            # create pseudo GT
            pseudo_gt = pseudo_gtmask(masks_dec).detach()
            loss_mask = balanced_mask_loss_ce(self._mask_logits, pseudo_gt, labels)

            return cls, cls_fg, {"cam": masks, "dec": masks_dec}, self._mask_logits, pseudo_gt, loss_mask

        def _rescale_and_clean(self, masks, image, labels):
            """Rescale masks to the image size and zero out absent classes.

            ``masks`` is resized (bilinear, ``align_corners=True``) to the
            spatial size of ``image``; the foreground channels (index 1 and
            up) of the resized tensor are then multiplied in place by
            ``labels`` broadcast over the spatial dimensions.

            Returns:
                The resized and masked tensor.
            """
            masks = F.interpolate(masks, size=image.size()[-2:], mode='bilinear', align_corners=True)
            masks[:, 1:] *= labels[:, :, None, None].type_as(masks)
            return masks