Example #1
    def new_encode_decode_encoder_encoder(img,
                                          img_metas,
                                          extract_layer_ids=[],
                                          use_resize=True):
        """Encode images with backbone and decode into a semantic segmentation
        map of the same size as input."""
        # `obj` is assumed to be the segmentor instance this function is
        # bound to; it is not defined in this snippet.
        x = obj.extract_feat(img, extract_layer_ids=extract_layer_ids)
        if len(extract_layer_ids) > 0:
            x, feats = x
        out = obj._decode_head_forward_test(x, img_metas)
        if use_resize:
            out = resize(input=out,
                         size=img.shape[2:],
                         mode='bilinear',
                         align_corners=obj.align_corners)

        if hasattr(obj, 'auxiliary_head'):
            out2 = obj._auxiliary_head_forward_test(x, img_metas)
            out2 = resize(input=out2,
                          size=img.shape[2:],
                          mode='bilinear',
                          align_corners=obj.align_corners)
            if len(extract_layer_ids) > 0:
                return out, out2, feats
            else:
                return out, out2, None

        if len(extract_layer_ids) > 0:
            return out, feats
        else:
            return out, None
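
All of these examples call a `resize` helper rather than `torch.nn.functional.interpolate` directly. A minimal sketch of such a wrapper, modeled on `mmseg.ops.wrappers.resize` (the real helper also emits an alignment warning, omitted here):

    import torch.nn.functional as F

    def resize(input,
               size=None,
               scale_factor=None,
               mode='nearest',
               align_corners=None,
               warning=True):
        # Thin pass-through to F.interpolate; `warning` is accepted only so
        # the call sites in these examples keep working.
        return F.interpolate(input, size, scale_factor, mode, align_corners)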
Example #2
    def forward(self, x):
        output = []

        # sub 1
        output.append(self.conv_sub1(x))

        # sub 2
        x = resize(
            x,
            scale_factor=0.5,
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.backbone.stem(x)
        x = self.backbone.maxpool(x)
        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        output.append(self.conv_sub2(x))

        # sub 4
        x = resize(
            x,
            scale_factor=0.5,
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        psp_outs = self.psp_modules(x) + [x]
        psp_outs = torch.cat(psp_outs, dim=1)
        x = self.psp_bottleneck(psp_outs)

        output.append(self.conv_sub4(x))

        return output
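
The two `scale_factor=0.5` calls above put the second and third branches at 1/2 and (cumulatively) 1/4 of the previous resolution. A standalone shape check with dummy tensors (illustrative only; in the snippet the second resize acts on backbone features, not on the raw image):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 512, 512)  # dummy input image
    x_half = F.interpolate(x, scale_factor=0.5, mode='bilinear',
                           align_corners=False)
    x_quarter = F.interpolate(x_half, scale_factor=0.5, mode='bilinear',
                              align_corners=False)
    print(x_half.shape, x_quarter.shape)
    # torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 128, 128])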
Example #3
    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)

        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]

        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        output = self.bottleneck(aspp_outs)
        if self.c1_bottleneck is not None:
            c1_output = self.c1_bottleneck(inputs[0])
            output = resize(
                input=output,
                size=c1_output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            output = torch.cat([output, c1_output], dim=1)
        output = self.sep_bottleneck(output)
        output = self.cls_seg(output)
        return output
Example #4
    def forward(self, x):
        outs = list(self.backbone(x))
        avg = F.adaptive_avg_pool2d(outs[-1], 1)
        avg_feat = self.conv_avg(avg)

        feature_up = resize(avg_feat,
                            size=outs[-1].shape[2:],
                            mode=self.upsample_mode,
                            align_corners=self.align_corners)
        arms_out = []
        for i in range(len(self.arms)):
            x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
            feature_up = resize(x_arm,
                                size=outs[len(outs) - 1 - i - 1].shape[2:],
                                mode=self.upsample_mode,
                                align_corners=self.align_corners)
            feature_up = self.convs[i](feature_up)
            arms_out.append(feature_up)

        feat_fuse = self.ffm(outs[0], arms_out[1])

        # `outputs` contains four feature maps:
        # `outs[0]` is used by the `STDCHead` auxiliary head,
        # the two feature maps in `arms_out` are used by auxiliary heads,
        # and `feat_fuse` is used by the decode head.
        outputs = [outs[0]] + list(arms_out) + [feat_fuse]
        return tuple(outputs)
Example #5
    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            # In some cases, fixing `scale_factor` (e.g. 2) is preferred, but
            # it cannot co-exist with `size` in `F.interpolate`.
            if 'scale_factor' in self.upsample_cfg:
                laterals[i - 1] = laterals[i - 1] + resize(
                    laterals[i], **self.upsample_cfg)
            else:
                prev_shape = laterals[i - 1].shape[2:]
                laterals[i - 1] = laterals[i - 1] + resize(
                    laterals[i], size=prev_shape, **self.upsample_cfg)

        # build outputs
        # part 1: from original levels
        outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
        ]
        # part 2: add extra levels
        if self.num_outs > len(outs):
            # use max pool to get more levels on top of outputs
            # (e.g., Faster R-CNN, Mask R-CNN)
            if not self.add_extra_convs:
                for i in range(self.num_outs - used_backbone_levels):
                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
            # add conv layers on top of original feature maps (RetinaNet)
            else:
                if self.add_extra_convs == 'on_input':
                    extra_source = inputs[self.backbone_end_level - 1]
                elif self.add_extra_convs == 'on_lateral':
                    extra_source = laterals[-1]
                elif self.add_extra_convs == 'on_output':
                    extra_source = outs[-1]
                else:
                    raise NotImplementedError
                outs.append(self.fpn_convs[used_backbone_levels](extra_source))
                for i in range(used_backbone_levels + 1, self.num_outs):
                    if self.relu_before_extra_convs:
                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                    else:
                        outs.append(self.fpn_convs[i](outs[-1]))
        return tuple(outs)
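
The comment in the top-down loop refers to a real `F.interpolate` constraint: `size` and `scale_factor` are mutually exclusive. A quick standalone check:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 8, 16, 16)
    print(F.interpolate(x, scale_factor=2.0, mode='nearest').shape)
    # torch.Size([1, 8, 32, 32])
    print(F.interpolate(x, size=(32, 32), mode='nearest').shape)
    # torch.Size([1, 8, 32, 32])
    # Passing both raises ValueError:
    # F.interpolate(x, size=(32, 32), scale_factor=2.0)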
Example #6
    def forward(self, x):
        x_4, x_8, x_16, x_32 = self.backbone(x)
        x_gap = self.gap_conv(x_32)

        x_32_arm = self.arm32(x_32)
        x_32_sum = x_32_arm + x_gap
        x_32_up = resize(input=x_32_sum, size=x_16.shape[2:], mode='nearest')
        x_32_up = self.conv_head32(x_32_up)

        x_16_arm = self.arm16(x_16)
        x_16_sum = x_16_arm + x_32_up
        x_16_up = resize(input=x_16_sum, size=x_8.shape[2:], mode='nearest')
        x_16_up = self.conv_head16(x_16_up)

        return x_16_up, x_32_up
Example #7
    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W), which is the feature map of the last decoder layer.
        """
        inputs = self._transform_inputs(inputs)

        # build laterals
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        laterals.append(self.psp_forward(inputs))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + resize(
                laterals[i],
                size=prev_shape,
                mode='bilinear',
                align_corners=self.align_corners)

        # build outputs
        fpn_outs = [
            self.fpn_convs[i](laterals[i])
            for i in range(used_backbone_levels - 1)
        ]
        # append psp feature
        fpn_outs.append(laterals[-1])

        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = resize(
                fpn_outs[i],
                size=fpn_outs[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        fpn_outs = torch.cat(fpn_outs, dim=1)
        feats = self.fpn_bottleneck(fpn_outs)
        return feats
Example #8
    def contrastive_losses(self, seg_logits, gt_semantic_seg, seg_logits1,
                           seg_label, img_metas):
        """Compute pixel-wise contrastive loss."""
        loss = dict()
        seg_logit = resize(input=seg_logits,
                           size=gt_semantic_seg.shape[2:],
                           mode='bilinear',
                           align_corners=self.align_corners)
        for i, logit in enumerate(seg_logits1):
            seg_logits1[i] = resize(input=logit,
                                    size=img_metas[0]['img_shape'][:2],
                                    mode='bilinear',
                                    align_corners=self.align_corners)

        for i, label in enumerate(seg_label):
            seg_label[i] = resize(input=label,
                                  size=img_metas[0]['img_shape'][:2],
                                  mode='bilinear',
                                  align_corners=self.align_corners)

        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logit, gt_semantic_seg)
        else:
            seg_weight = None
        gt_semantic_seg = gt_semantic_seg.squeeze(1)
        loss['loss_seg'] = self.loss_decode(seg_logits1,
                                            seg_label,
                                            seg_logit,
                                            gt_semantic_seg,
                                            img_metas,
                                            weight=seg_weight,
                                            ignore_index=self.ignore_index)

        loss['acc_seg'] = accuracy(seg_logit, gt_semantic_seg)

        return loss
Example #9
    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == len(self.in_channels), (
            'Length of inputs must be the same as self.in_channels!')

        feats = [
            self.conv_layers[i - self.start_level](inputs[i])
            for i in range(self.start_level, self.backbone_end_level)
        ]

        h, w = feats[0].shape[2:]
        for i in range(1, len(feats)):
            feats[i] = resize(feats[i],
                              size=(h, w),
                              mode='bilinear',
                              align_corners=self.align_corners)

        feat = torch.cat(feats, dim=1)
        concat_feat = torch.cat([
            self.dilation_layers[i](feat) for i in range(len(self.dilations))
        ],
                                dim=1)

        outs = []

        # Default: outs[2] is the output of JPU for decoder head, outs[1] is
        # the feature map from backbone for auxiliary head. Additionally,
        # outs[0] can also be used for auxiliary head.
        for i in range(self.start_level, self.backbone_end_level - 1):
            outs.append(inputs[i])
        outs.append(concat_feat)
        return tuple(outs)
Example #10
    def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
        """Resize pos_embed weights.

        Resize pos_embed with the given interpolation ``mode``.

        Args:
            pos_embed (torch.Tensor): Position embedding weights.
            input_shape (tuple): Tuple of (downsampled input image height,
                downsampled input image width).
            pos_shape (tuple): The resolution of the downsampled original
                training image.
            mode (str): Algorithm used for upsampling:
                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
                ``'trilinear'``.

        Returns:
            torch.Tensor: The resized pos_embed of shape [B, L_new, C].
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        pos_h, pos_w = pos_shape
        cls_token_weight = pos_embed[:, 0]
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
        pos_embed_weight = pos_embed_weight.reshape(
            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = resize(pos_embed_weight,
                                  size=input_shape,
                                  align_corners=False,
                                  mode=mode)
        cls_token_weight = cls_token_weight.unsqueeze(1)
        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
        pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
        return pos_embed
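
A hypothetical call, treating the function above as standalone and assuming a ViT position embedding with one class token and a 32x32 patch grid being adapted to a 48x48 grid (relies on the `resize` helper sketched earlier):

    import torch

    pos_embed = torch.randn(1, 1 + 32 * 32, 768)  # [B, 1 + L, C]
    resized = resize_pos_embed(pos_embed,
                               input_shape=(48, 48),
                               pos_shape=(32, 32),
                               mode='bicubic')
    print(resized.shape)  # torch.Size([1, 2305, 768]) == [B, 1 + 48 * 48, C]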
Example #11
 def simple_test(self, img: torch.Tensor, img_meta: Iterable,
                 **kwargs) -> list:
     if not self.is_cuda_available:
         img = img.detach().cpu()
     elif self.device_id >= 0:
         img = img.cuda(self.device_id)
     device_type = img.device.type
     self.io_binding.bind_input(name='input',
                                device_type=device_type,
                                device_id=self.device_id,
                                element_type=np.float32,
                                shape=img.shape,
                                buffer_ptr=img.data_ptr())
     self.sess.run_with_iobinding(self.io_binding)
     seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
     # whole might support dynamic reshape
     ori_shape = img_meta[0]['ori_shape']
     if not (ori_shape[0] == seg_pred.shape[-2]
             and ori_shape[1] == seg_pred.shape[-1]):
         seg_pred = torch.from_numpy(seg_pred).float()
         seg_pred = resize(seg_pred,
                           size=tuple(ori_shape[:2]),
                           mode='nearest')
         seg_pred = seg_pred.long().detach().cpu().numpy()
     seg_pred = seg_pred[0]
     seg_pred = list(seg_pred)
     return seg_pred
Example #12
    def losses(self, seg_logit, seg_label):
        """Compute segmentation loss."""
        loss = dict()
        seg_logit = resize(
            input=seg_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logit, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)
        for loss_decode in self.loss_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(
                    seg_logit,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)
            else:
                loss[loss_decode.loss_name] += loss_decode(
                    seg_logit,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)

        loss['acc_seg'] = accuracy(seg_logit, seg_label)
        return loss
Example #13
    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(input=x,
                       size=inputs[0].shape[2:],
                       mode='bilinear',
                       align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs
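
For the `'resize_concat'` branch, the effect can be reproduced standalone: every selected level is upsampled to the first level's spatial size and the channels are concatenated. A sketch with hypothetical strides and channel counts:

    import torch
    import torch.nn.functional as F

    feats = [torch.randn(1, 32, 64, 64),   # stride 4
             torch.randn(1, 64, 32, 32),   # stride 8
             torch.randn(1, 128, 16, 16)]  # stride 16
    upsampled = [F.interpolate(f, size=feats[0].shape[2:], mode='bilinear',
                               align_corners=False) for f in feats]
    merged = torch.cat(upsampled, dim=1)
    print(merged.shape)  # torch.Size([1, 224, 64, 64])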
Example #14
    def forward(self, x):
        """Forward function."""
        pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
        # [batch_size, channels, h, w]
        x = self.input_redu_conv(x)
        # [batch_size, channels, pool_scale, pool_scale]
        pooled_x = self.pooled_redu_conv(pooled_x)
        batch_size = x.size(0)
        # [batch_size, pool_scale * pool_scale, channels]
        pooled_x = pooled_x.view(batch_size, self.channels,
                                 -1).permute(0, 2, 1).contiguous()
        # [batch_size, h * w, pool_scale * pool_scale]
        global_info = resize(
            self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
        affinity_matrix = self.gla(x + global_info).permute(
            0, 2, 3, 1).reshape(batch_size, -1, self.pool_scale**2)
        affinity_matrix = torch.sigmoid(affinity_matrix)
        # [batch_size, h * w, channels]
        z_out = torch.matmul(affinity_matrix, pooled_x)
        # [batch_size, channels, h * w]
        z_out = z_out.permute(0, 2, 1).contiguous()
        # [batch_size, channels, h, w]
        z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
        z_out = self.residual_conv(z_out)
        z_out = F.relu(z_out + x)
        if self.fusion:
            z_out = self.fusion_conv(z_out)

        return z_out
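
The shape bookkeeping in the comments can be verified in isolation; a sketch with hypothetical sizes showing that the affinity matmul maps pooled features back onto the full spatial grid:

    import torch

    b, c, h, w, ps = 2, 16, 8, 8, 2             # hypothetical sizes
    affinity = torch.rand(b, h * w, ps * ps)    # [b, h*w, ps*ps]
    pooled = torch.rand(b, ps * ps, c)          # [b, ps*ps, c]
    z = torch.matmul(affinity, pooled)          # [b, h*w, c]
    z = z.permute(0, 2, 1).reshape(b, c, h, w)  # [b, c, h, w]
    print(z.shape)  # torch.Size([2, 16, 8, 8])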
Example #15
 def forward(self, x_d, x_s):
     detail_dwconv = self.detail_dwconv(x_d)
     detail_down = self.detail_down(x_d)
     semantic_conv = self.semantic_conv(x_s)
     semantic_dwconv = self.semantic_dwconv(x_s)
     semantic_conv = resize(input=semantic_conv,
                            size=detail_dwconv.shape[2:],
                            mode='bilinear',
                            align_corners=self.align_corners)
     fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv)
     fuse_2 = detail_down * torch.sigmoid(semantic_dwconv)
     fuse_2 = resize(input=fuse_2,
                     size=fuse_1.shape[2:],
                     mode='bilinear',
                     align_corners=self.align_corners)
     output = self.conv(fuse_1 + fuse_2)
     return output
Example #16
    def forward(self, query_feats, key_feats):
        """Forward function."""
        context = super(ObjectAttentionBlock,
                        self).forward(query_feats, key_feats)
        output = self.bottleneck(torch.cat([context, query_feats], dim=1))
        if self.query_downsample is not None:
            # The original line called `resize(query_feats)` with no target
            # size, which `F.interpolate` rejects; the intent is presumably
            # to restore `output` to the query features' spatial size.
            output = resize(output, size=query_feats.shape[2:])

        return output
Example #17
 def forward(self, *inputs):
     x = inputs[0]
     if len(inputs) == 2:
         if x.shape != inputs[1].shape:
             res = resize(inputs[1],
                          size=(x.shape[2], x.shape[3]),
                          mode='bilinear',
                          align_corners=False)
         else:
             res = inputs[1]
         x = x + self.res_conv_unit1(res)
     x = self.res_conv_unit2(x)
     x = resize(x,
                scale_factor=2,
                mode='bilinear',
                align_corners=self.align_corners)
     x = self.project(x)
     return x
Example #18
 def encode_decode(self, img, img_metas):
     """Encode images with backbone and decode into a semantic segmentation
     map of the same size as input."""
     x = self.extract_feat(img)
     out = self._decode_head_forward_test(x, img_metas)
     out = resize(input=out,
                  size=img.shape[2:],
                  mode='bilinear',
                  align_corners=self.align_corners)
     return out
Example #19
 def forward(self, input):
     conv_out = self.conv(input)
     pool_out = self.pool(input)
     pool_out = resize(input=pool_out,
                       size=conv_out.size()[2:],
                       mode='bilinear',
                       align_corners=False)
     output = torch.cat([conv_out, pool_out], 1)
     output = self.bn(output)
     output = self.act(output)
     return output
Example #20
 def forward(self, x):
     """Forward function."""
     ppm_outs = []
     for ppm in self:
         ppm_out = ppm(x)
         upsampled_ppm_out = resize(ppm_out,
                                    size=x.size()[2:],
                                    mode='bilinear',
                                    align_corners=self.align_corners)
         ppm_outs.append(upsampled_ppm_out)
     return ppm_outs
Example #21
    def slide_inference(self, img, img_meta, rescale):
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.
        """

        h_stride, w_stride = self.test_cfg.stride
        h_crop, w_crop = self.test_cfg.crop_size
        batch_size, _, h_img, w_img = img.size()
        num_classes = self.num_classes
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = [
            img.new_zeros((batch_size, num_classes, h_img, w_img)),
            img.new_zeros((batch_size, num_classes, h_img, w_img)),
            img.new_zeros((batch_size, num_classes, h_img, w_img))
        ]

        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                crop_seg_logits = self.encode_decode(crop_img, img_meta)
                for idx, crop_seg_logit in enumerate(crop_seg_logits):
                    preds[idx] += F.pad(
                        crop_seg_logit,
                        (int(x1), int(preds[idx].shape[3] - x2), int(y1),
                         int(preds[idx].shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        if torch.onnx.is_in_onnx_export():
            # cast count_mat to constant while exporting to ONNX
            count_mat = torch.from_numpy(
                count_mat.cpu().detach().numpy()).to(device=img.device)
        preds = [pred / count_mat for pred in preds]
        if rescale:
            preds = [
                resize(pred,
                       size=img_meta[0]['ori_shape'][:2],
                       mode='bilinear',
                       align_corners=self.align_corners,
                       warning=False) for pred in preds
            ]
        # A --> [A, ...]: return a list of prediction maps instead of a
        # single one.
        return preds
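
The grid arithmetic at the top of `slide_inference` guarantees the windows cover the image even when the stride does not divide it evenly; the last window is clamped to the border. A standalone check with assumed sizes:

    # h_img=512, h_crop=256, h_stride=171 are illustrative values only.
    h_img, h_crop, h_stride = 512, 256, 171
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    print(h_grids)  # 3
    for h_idx in range(h_grids):
        y1 = h_idx * h_stride
        y2 = min(y1 + h_crop, h_img)
        y1 = max(y2 - h_crop, 0)
        print(y1, y2)  # (0, 256), (171, 427), (256, 512)
    # The third window would start at 342 and overrun the image, so it is
    # clamped back to (256, 512).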
Example #22
 def head(self, x, inputs):
     aspp_outs = [
         resize(self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
     ]
     aspp_outs.extend(self.aspp_modules(x))
     aspp_outs = torch.cat(aspp_outs, dim=1)
     output = self.bottleneck(aspp_outs)
     if self.c1_bottleneck is not None:
         c1_output = self.c1_bottleneck(inputs[0])
         output = resize(input=output,
                         size=c1_output.shape[2:],
                         mode='bilinear',
                         align_corners=self.align_corners)
         output = torch.cat([output, c1_output], dim=1)
     output = self.sep_bottleneck(output)
     output = self.cls_seg(output)
     return output
Example #23
    def forward(self, higher_res_feature, lower_res_feature):
        lower_res_feature = resize(lower_res_feature,
                                   size=higher_res_feature.size()[2:],
                                   mode='bilinear',
                                   align_corners=self.align_corners)
        lower_res_feature = self.dwconv(lower_res_feature)
        lower_res_feature = self.conv_lower_res(lower_res_feature)

        higher_res_feature = self.conv_higher_res(higher_res_feature)
        out = higher_res_feature + lower_res_feature
        return self.relu(out)
Example #24
    def forward(self, inputs):
        """Forward function."""

        inputs = self._transform_inputs(inputs)

        # build laterals
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        laterals.append(self.psp_forward(inputs))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            # Out-of-place add (as in the other FPN-style examples); in-place
            # `+=` can break autograd.
            laterals[i - 1] = laterals[i - 1] + resize(
                laterals[i],
                size=prev_shape,
                mode='bilinear',
                align_corners=self.align_corners)

        # build outputs
        fpn_outs = [
            self.fpn_convs[i](laterals[i])
            for i in range(used_backbone_levels - 1)
        ]
        # append psp feature
        fpn_outs.append(laterals[-1])

        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = resize(
                fpn_outs[i],
                size=fpn_outs[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output = self.fpn_bottleneck(fpn_outs)
        output = self.cls_seg(output)
        return output
Example #25
    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        x = inputs[-1]

        x = self.aspp_conv(x) * resize(self.image_pool(x),
                                       size=x.size()[2:],
                                       mode='bilinear',
                                       align_corners=self.align_corners)
        x = self.conv_up_input(x)

        for i in range(len(self.branch_channels) - 1, -1, -1):
            x = resize(x,
                       size=inputs[i].size()[2:],
                       mode='bilinear',
                       align_corners=self.align_corners)
            x = torch.cat([x, self.convs[i](inputs[i])], 1)
            x = self.conv_ups[i](x)

        return self.cls_seg(x)
Example #26
    def whole_inference(self, img, img_meta, rescale):
        """Inference with full image."""

        seg_logit = self.encode_decode(img, img_meta)
        if rescale:
            seg_logit = resize(seg_logit,
                               size=img_meta[0]['ori_shape'][:2],
                               mode='bilinear',
                               align_corners=self.align_corners,
                               warning=False)

        return seg_logit
Example #27
 def encode_decode(self, img, img_metas):
     """Encode images with backbone and decode into a semantic segmentation
     map of the same size as input."""
     x = self.extract_feat_l(img)
     out = self.decode_head_l[0].forward_test(x, img_metas, self.test_cfg)
     for i in range(1, self.num_stages):
         out = self.decode_head_l[i].forward_test(x, out, img_metas,
                                                  self.test_cfg)
     out = resize(input=out,
                  size=img.shape[2:],
                  mode='bilinear',
                  align_corners=self.align_corners)
     return out
Example #28
 def forward(self, x_low, x_high):
     x_low = resize(x_low,
                    size=x_high.size()[2:],
                    mode='bilinear',
                    align_corners=self.align_corners)
     # Note: different from the original paper, `x_low` goes through
     # `self.conv_low` rather than another 1x1 conv classifier
     # before being used by the auxiliary head.
     x_low = self.conv_low(x_low)
     x_high = self.conv_high(x_high)
     x = x_low + x_high
     x = F.relu(x, inplace=True)
     return x, x_low
Example #29
 def forward(self, inputs):
     """Forward function."""
     x = self._transform_inputs(inputs)
     aspp_outs = [
         resize(self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
     ]
     aspp_outs.extend(self.aspp_modules(x))
     aspp_outs = torch.cat(aspp_outs, dim=1)
     output = self.bottleneck(aspp_outs)
     output = self.cls_seg(output)
     return output
Example #30
    def forward(self, inputs, prev_output):
        """Forward function."""
        feats = self.bottleneck(inputs[self.feature_key])
        cur_prob = resize(input=prev_output,
                          size=feats.shape[2:],
                          mode='bilinear',
                          align_corners=self.align_corners)
        context = self.spatial_gather_module(feats, cur_prob)
        output = self.object_context_block(feats, context)

        # build decoder
        for i in range(self.decoder_stage):
            low_level_key = self.low_level_key[i]
            low_level_feats = self.project[i](inputs[low_level_key])
            output = resize(output,
                            size=low_level_feats.shape[2:],
                            mode='bilinear',
                            align_corners=self.align_corners)
            output = torch.cat([output, low_level_feats], dim=1)
            output = self.fuse[i](output)

        output = self.cls_seg(output)
        return output