Example #1
class _PUPHead(nn.HybridBlock):
    def __init__(self, nclass, aux, norm_layer=nn.BatchNorm, norm_kwargs=None):
        super(_PUPHead, self).__init__()
        self.aux = aux
        with self.name_scope():
            self.conv0 = ConvModule2d(256,
                                      3,
                                      1,
                                      1,
                                      norm_layer=norm_layer,
                                      norm_kwargs=norm_kwargs)
            self.conv1 = ConvModule2d(256,
                                      3,
                                      1,
                                      1,
                                      norm_layer=norm_layer,
                                      norm_kwargs=norm_kwargs)
            self.conv2 = ConvModule2d(256,
                                      3,
                                      1,
                                      1,
                                      norm_layer=norm_layer,
                                      norm_kwargs=norm_kwargs)
            self.conv3 = ConvModule2d(256,
                                      3,
                                      1,
                                      1,
                                      norm_layer=norm_layer,
                                      norm_kwargs=norm_kwargs)
            self.conv4 = ConvModule2d(nclass,
                                      3,
                                      1,
                                      1,
                                      norm_layer=norm_layer,
                                      norm_kwargs=norm_kwargs)
            if self.aux:
                self.aux_head = HybridConcurrentIsolate()
                self.aux_head.add(_SegHead(nclass, norm_layer, norm_kwargs),
                                  _SegHead(nclass, norm_layer, norm_kwargs),
                                  _SegHead(nclass, norm_layer, norm_kwargs),
                                  _SegHead(nclass, norm_layer, norm_kwargs))

    def hybrid_forward(self, F, x, *args, **kwargs):
        outputs = []
        out = self.conv0(x)
        out = F.contrib.BilinearResize2D(out, scale_height=2., scale_width=2.)
        out = self.conv1(out)
        out = F.contrib.BilinearResize2D(out, scale_height=2., scale_width=2.)
        out = self.conv2(out)
        out = F.contrib.BilinearResize2D(out, scale_height=2., scale_width=2.)
        out = self.conv4(self.conv3(out))
        outputs.append(out)
        if self.aux:
            aux_outs = self.aux_head(x, *args)
            outputs = outputs + aux_outs
        return tuple(outputs)
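Common to all of the examples in this listing: each child block added to HybridConcurrentIsolate receives its own positional input, and the container returns the per-child outputs as a list (which is why `outputs + aux_outs` works above). The real class ships with BebDong/MXNetSeg; the following is only a minimal sketch inferred from these call sites, not the project's implementation.

from mxnet.gluon import nn

class HybridConcurrentIsolate(nn.HybridSequential):
    """Sketch only: run the i-th child block on the i-th input and collect
    the results in a list. Inferred from the usage above."""

    def hybrid_forward(self, F, x, *args):
        inputs = (x,) + args
        return [block(inp) for block, inp in zip(self, inputs)]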
Example #2
 def __init__(self, nclass, aux, norm_layer=nn.BatchNorm, norm_kwargs=None):
     super(_NaiveHead, self).__init__()
     self.aux = aux
     with self.name_scope():
         self.head = _SegHead(nclass, norm_layer, norm_kwargs)
         if self.aux:
             self.aux_head = HybridConcurrentIsolate()
             self.aux_head.add(_SegHead(nclass, norm_layer, norm_kwargs),
                               _SegHead(nclass, norm_layer, norm_kwargs),
                               _SegHead(nclass, norm_layer, norm_kwargs))
Example #3
File: bisenet.py  Project: BebDong/MXNetSeg
class BiSeNetR(SegBaseResNet):
    def __init__(self,
                 nclass,
                 backbone='resnet18',
                 aux=True,
                 height=None,
                 width=None,
                 base_size=520,
                 crop_size=480,
                 pretrained_base=False,
                 norm_layer=nn.BatchNorm,
                 norm_kwargs=None,
                 **kwargs):
        super(BiSeNetR, self).__init__(nclass,
                                       aux,
                                       backbone,
                                       height,
                                       width,
                                       base_size,
                                       crop_size,
                                       pretrained_base,
                                       dilate=False,
                                       norm_layer=norm_layer,
                                       norm_kwargs=norm_kwargs)
        with self.name_scope():
            self.head = _BiSeNetHead(nclass,
                                     norm_layer=norm_layer,
                                     norm_kwargs=norm_kwargs)
            if self.aux:
                self.aux_head = HybridConcurrentIsolate()
                self.aux_head.add(
                    FCNHead(nclass,
                            norm_layer=norm_layer,
                            norm_kwargs=norm_kwargs),
                    FCNHead(nclass,
                            norm_layer=norm_layer,
                            norm_kwargs=norm_kwargs))

    def hybrid_forward(self, F, x, *args, **kwargs):
        _, _, c3, c4 = self.base_forward(x)
        outputs = []
        x = self.head(x, c3, c4)
        outputs.append(x)

        if self.aux:
            aux_outs = self.aux_head(c4, c3)
            outputs = outputs + aux_outs
        outputs = [
            F.contrib.BilinearResize2D(out, **self._up_kwargs)
            for out in outputs
        ]
        return tuple(outputs)
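A quick smoke test of BiSeNetR as defined above might look like the following. The class count, input resolution, and context are placeholders, and the expected output shape assumes that self._up_kwargs resolves to the given height and width.

import mxnet as mx

# Hypothetical smoke test; BiSeNetR is assumed importable as defined above.
net = BiSeNetR(nclass=19, backbone='resnet18', aux=True,
               height=480, width=480, pretrained_base=False)
net.initialize(ctx=mx.cpu())

x = mx.nd.random.uniform(shape=(1, 3, 480, 480))
main_out, aux_out1, aux_out2 = net(x)    # three outputs when aux=True
print(main_out.shape)                    # expected (1, 19, 480, 480)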
Example #4
 def _build_decoder(decoder, layer_norm_eps):
     if decoder == 'naive':
         out_indices = (10, 15, 20, 24)
         head = _NaiveHead
     elif decoder == 'pup':
         out_indices = (10, 15, 20, 24)
         head = _PUPHead
     else:
         out_indices = (6, 12, 18, 24)
         head = _MLAHead
     out_indices = tuple([i - 1 for i in out_indices])
     layer_norms = HybridConcurrentIsolate()
     for i in range(len(out_indices)):
         layer_norms.add(nn.LayerNorm(epsilon=layer_norm_eps))
     return out_indices, layer_norms, head
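For orientation, calling this helper with the 'pup' decoder would be expected to behave as follows. The epsilon and class count are illustrative, and the helper is assumed to be reachable as a plain function or staticmethod.

# Hypothetical call; the values are illustrative only.
out_indices, layer_norms, head_cls = _build_decoder('pup', layer_norm_eps=1e-6)
print(out_indices)          # (9, 14, 19, 23) -- zero-based after the `i - 1` shift
print(head_cls.__name__)    # '_PUPHead'
head = head_cls(nclass=19, aux=True)    # the selected head class is instantiated later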
Example #5
File: danet.py  Project: BebDong/MXNetSeg
    def __init__(self, nclass, backbone='resnet50', aux=False, height=None, width=None,
                 base_size=520, crop_size=480, pretrained_base=True, norm_layer=nn.BatchNorm,
                 norm_kwargs=None, **kwargs):
        super(DANet, self).__init__(nclass, aux, backbone, height, width, base_size,
                                    crop_size, pretrained_base, dilate=True,
                                    norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        with self.name_scope():
            self.head = _DANetHead(nclass, self.stage_channels[3], norm_layer, norm_kwargs)
            if self.aux:
                pam_layer = nn.HybridSequential()
                pam_layer.add(nn.Dropout(0.1))
                pam_layer.add(nn.Conv2D(nclass, 1))

                cam_layer = nn.HybridSequential()
                cam_layer.add(nn.Dropout(0.1))
                cam_layer.add(nn.Conv2D(nclass, 1))

                self.aux_head = HybridConcurrentIsolate()
                self.aux_head.add(pam_layer, cam_layer)
Example #6
class _NaiveHead(nn.HybridBlock):
    def __init__(self, nclass, aux, norm_layer=nn.BatchNorm, norm_kwargs=None):
        super(_NaiveHead, self).__init__()
        self.aux = aux
        with self.name_scope():
            self.head = _SegHead(nclass, norm_layer, norm_kwargs)
            if self.aux:
                self.aux_head = HybridConcurrentIsolate()
                self.aux_head.add(_SegHead(nclass, norm_layer, norm_kwargs),
                                  _SegHead(nclass, norm_layer, norm_kwargs),
                                  _SegHead(nclass, norm_layer, norm_kwargs))

    def hybrid_forward(self, F, x, *args, **kwargs):
        outputs = []
        out = self.head(x)
        outputs.append(out)
        if self.aux:
            aux_outs = self.aux_head(*args)
            outputs = outputs + aux_outs
        return tuple(outputs)
Example #7
File: danet.py  Project: BebDong/MXNetSeg
class DANet(SegBaseResNet):
    """
    ResNet-based DANet.
    Reference:
        J. Fu et al., “Dual Attention Network for Scene Segmentation,”
        in IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    """

    def __init__(self, nclass, backbone='resnet50', aux=False, height=None, width=None,
                 base_size=520, crop_size=480, pretrained_base=True, norm_layer=nn.BatchNorm,
                 norm_kwargs=None, **kwargs):
        super(DANet, self).__init__(nclass, aux, backbone, height, width, base_size,
                                    crop_size, pretrained_base, dilate=True,
                                    norm_layer=norm_layer, norm_kwargs=norm_kwargs)
        with self.name_scope():
            self.head = _DANetHead(nclass, self.stage_channels[3], norm_layer, norm_kwargs)
            if self.aux:
                pam_layer = nn.HybridSequential()
                pam_layer.add(nn.Dropout(0.1))
                pam_layer.add(nn.Conv2D(nclass, 1))

                cam_layer = nn.HybridSequential()
                cam_layer.add(nn.Dropout(0.1))
                cam_layer.add(nn.Conv2D(nclass, 1))

                self.aux_head = HybridConcurrentIsolate()
                self.aux_head.add(pam_layer, cam_layer)

    def hybrid_forward(self, F, x, *args, **kwargs):
        _, _, _, c4 = self.base_forward(x)
        outputs = []
        out, pam_out, cam_out = self.head(c4)
        outputs.append(out)

        if self.aux:
            aux_outs = self.aux_head(pam_out, cam_out)
            outputs = outputs + aux_outs

        outputs = [F.contrib.BilinearResize2D(out, **self._up_kwargs) for out in outputs]
        return tuple(outputs)
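Every model in this listing returns a tuple whose first element is the main prediction and whose remaining elements come from the auxiliary head, so training typically adds a down-weighted cross-entropy term per auxiliary output. A minimal sketch of that pattern follows; the 0.4 weight and the Gluon loss class are illustrative choices, not taken from MXNetSeg.

from mxnet import gluon

criterion = gluon.loss.SoftmaxCrossEntropyLoss(axis=1, sparse_label=True)

def multi_output_loss(outputs, target, aux_weight=0.4):
    # outputs: tuple of (N, nclass, H, W) score maps; target: (N, H, W) labels.
    loss = criterion(outputs[0], target)
    for aux_out in outputs[1:]:
        loss = loss + aux_weight * criterion(aux_out, target)
    return loss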
Example #8
 def __init__(self, nclass, aux, norm_layer=nn.BatchNorm, norm_kwargs=None):
     super(_PUPHead, self).__init__()
     self.aux = aux
     with self.name_scope():
         self.conv0 = ConvModule2d(256,
                                   3,
                                   1,
                                   1,
                                   norm_layer=norm_layer,
                                   norm_kwargs=norm_kwargs)
         self.conv1 = ConvModule2d(256,
                                   3,
                                   1,
                                   1,
                                   norm_layer=norm_layer,
                                   norm_kwargs=norm_kwargs)
         self.conv2 = ConvModule2d(256,
                                   3,
                                   1,
                                   1,
                                   norm_layer=norm_layer,
                                   norm_kwargs=norm_kwargs)
         self.conv3 = ConvModule2d(256,
                                   3,
                                   1,
                                   1,
                                   norm_layer=norm_layer,
                                   norm_kwargs=norm_kwargs)
         self.conv4 = ConvModule2d(nclass,
                                   3,
                                   1,
                                   1,
                                   norm_layer=norm_layer,
                                   norm_kwargs=norm_kwargs)
         if self.aux:
             self.aux_head = HybridConcurrentIsolate()
             self.aux_head.add(_SegHead(nclass, norm_layer, norm_kwargs),
                               _SegHead(nclass, norm_layer, norm_kwargs),
                               _SegHead(nclass, norm_layer, norm_kwargs),
                               _SegHead(nclass, norm_layer, norm_kwargs))
Example #9
File: bisenet.py  Project: BebDong/MXNetSeg
    def __init__(self,
                 nclass,
                 backbone='xception39',
                 aux=True,
                 height=None,
                 width=None,
                 base_size=520,
                 crop_size=480,
                 pretrained_base=False,
                 norm_layer=nn.BatchNorm,
                 norm_kwargs=None,
                 **kwargs):
        super(BiSeNetX, self).__init__(nclass, aux, height, width, base_size,
                                       crop_size)
        assert backbone == 'xception39', 'support only xception39 as the backbone.'
        pretrained = xception39(pretrained_base,
                                norm_layer=norm_layer,
                                norm_kwargs=norm_kwargs)
        with self.name_scope():
            self.conv = pretrained.conv1
            self.max_pool = pretrained.maxpool
            self.layer1 = pretrained.layer1
            self.layer2 = pretrained.layer2
            self.layer3 = pretrained.layer3

            self.head = _BiSeNetHead(nclass,
                                     norm_layer=norm_layer,
                                     norm_kwargs=norm_kwargs)
            if self.aux:
                self.aux_head = HybridConcurrentIsolate()
                self.aux_head.add(
                    FCNHead(nclass,
                            norm_layer=norm_layer,
                            norm_kwargs=norm_kwargs),
                    FCNHead(nclass,
                            norm_layer=norm_layer,
                            norm_kwargs=norm_kwargs))
Example #10
 def __init__(self,
              nclass,
              backbone='resnet50',
              aux=True,
              height=None,
              width=None,
              base_size=520,
              crop_size=480,
              pretrained_base=True,
              norm_layer=nn.BatchNorm,
              norm_kwargs=None,
              **kwargs):
     super(AttentionToScale, self).__init__(nclass,
                                            aux,
                                            backbone,
                                            height,
                                            width,
                                            base_size,
                                            crop_size,
                                            pretrained_base,
                                            dilate=True,
                                            norm_layer=norm_layer,
                                            norm_kwargs=norm_kwargs)
     with self.name_scope():
         self.head = _AttentionHead(nclass,
                                    norm_layer=norm_layer,
                                    norm_kwargs=norm_kwargs)
         if self.aux:
             self.aux_head = HybridConcurrentIsolate()
             self.aux_head.add(
                 AuxHead(nclass,
                         norm_layer=norm_layer,
                         norm_kwargs=norm_kwargs),
                 AuxHead(nclass,
                         norm_layer=norm_layer,
                         norm_kwargs=norm_kwargs))
Example #11
File: bisenet.py  Project: BebDong/MXNetSeg
 def __init__(self,
              nclass,
              backbone='resnet18',
              aux=True,
              height=None,
              width=None,
              base_size=520,
              crop_size=480,
              pretrained_base=False,
              norm_layer=nn.BatchNorm,
              norm_kwargs=None,
              **kwargs):
     super(BiSeNetR, self).__init__(nclass,
                                    aux,
                                    backbone,
                                    height,
                                    width,
                                    base_size,
                                    crop_size,
                                    pretrained_base,
                                    dilate=False,
                                    norm_layer=norm_layer,
                                    norm_kwargs=norm_kwargs)
     with self.name_scope():
         self.head = _BiSeNetHead(nclass,
                                  norm_layer=norm_layer,
                                  norm_kwargs=norm_kwargs)
         if self.aux:
             self.aux_head = HybridConcurrentIsolate()
             self.aux_head.add(
                 FCNHead(nclass,
                         norm_layer=norm_layer,
                         norm_kwargs=norm_kwargs),
                 FCNHead(nclass,
                         norm_layer=norm_layer,
                         norm_kwargs=norm_kwargs))
Example #12
class AttentionToScale(SegBaseResNet):
    """
    ResNet-based attention-to-scale model.
    Only supports training with two scales, 1.0x and 0.5x.
    Reference: L. C. Chen, Y. Yang, J. Wang, W. Xu, and A. L. Yuille,
        “Attention to Scale: Scale-Aware Semantic Image Segmentation,” in IEEE Conference
         on Computer Vision and Pattern Recognition, 2016, pp. 3640–3649.
    """
    def __init__(self,
                 nclass,
                 backbone='resnet50',
                 aux=True,
                 height=None,
                 width=None,
                 base_size=520,
                 crop_size=480,
                 pretrained_base=True,
                 norm_layer=nn.BatchNorm,
                 norm_kwargs=None,
                 **kwargs):
        super(AttentionToScale, self).__init__(nclass,
                                               aux,
                                               backbone,
                                               height,
                                               width,
                                               base_size,
                                               crop_size,
                                               pretrained_base,
                                               dilate=True,
                                               norm_layer=norm_layer,
                                               norm_kwargs=norm_kwargs)
        with self.name_scope():
            self.head = _AttentionHead(nclass,
                                       norm_layer=norm_layer,
                                       norm_kwargs=norm_kwargs)
            if self.aux:
                self.aux_head = HybridConcurrentIsolate()
                self.aux_head.add(
                    AuxHead(nclass,
                            norm_layer=norm_layer,
                            norm_kwargs=norm_kwargs),
                    AuxHead(nclass,
                            norm_layer=norm_layer,
                            norm_kwargs=norm_kwargs))

    def hybrid_forward(self, F, x, *args, **kwargs):
        # 1.0x scale forward
        _, _, _, c4 = self.base_forward(x)
        # 0.5x scale forward
        xh = F.contrib.BilinearResize2D(x,
                                        height=self._up_kwargs['height'] // 2,
                                        width=self._up_kwargs['width'] // 2)
        _, _, _, c4h = self.base_forward(xh)
        # head
        outputs = []
        x = self.head(c4, c4h)
        outputs.append(x)

        if self.aux:
            aux_outs = self.aux_head(c4, c4h)
            outputs = outputs + aux_outs

        outputs = [
            F.contrib.BilinearResize2D(out, **self._up_kwargs)
            for out in outputs
        ]
        return tuple(outputs)
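The internals of _AttentionHead are not shown in this listing; for context, the attention-to-scale idea from the cited paper is a per-pixel softmax over the scale axis that weights the per-scale score maps, roughly as sketched below. This is a paraphrase of the paper's fusion step, not the repository's head.

def fuse_two_scales(F, score_1x, score_05x_up, attention_logits):
    """Per-pixel softmax over the scale axis, then a weighted sum of score maps.
    F is mx.nd or mx.sym; attention_logits has shape (N, 2, H, W), one channel
    per scale; both score maps are (N, nclass, H, W) at matching resolution."""
    weights = F.softmax(attention_logits, axis=1)
    w_1x = F.slice_axis(weights, axis=1, begin=0, end=1)
    w_05x = F.slice_axis(weights, axis=1, begin=1, end=2)
    return F.broadcast_mul(score_1x, w_1x) + F.broadcast_mul(score_05x_up, w_05x)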
Example #13
File: bisenet.py  Project: BebDong/MXNetSeg
class BiSeNetX(SegBaseModel):
    def __init__(self,
                 nclass,
                 backbone='xception39',
                 aux=True,
                 height=None,
                 width=None,
                 base_size=520,
                 crop_size=480,
                 pretrained_base=False,
                 norm_layer=nn.BatchNorm,
                 norm_kwargs=None,
                 **kwargs):
        super(BiSeNetX, self).__init__(nclass, aux, height, width, base_size,
                                       crop_size)
        assert backbone == 'xception39', 'support only xception39 as the backbone.'
        pretrained = xception39(pretrained_base,
                                norm_layer=norm_layer,
                                norm_kwargs=norm_kwargs)
        with self.name_scope():
            self.conv = pretrained.conv1
            self.max_pool = pretrained.maxpool
            self.layer1 = pretrained.layer1
            self.layer2 = pretrained.layer2
            self.layer3 = pretrained.layer3

            self.head = _BiSeNetHead(nclass,
                                     norm_layer=norm_layer,
                                     norm_kwargs=norm_kwargs)
            if self.aux:
                self.aux_head = HybridConcurrentIsolate()
                self.aux_head.add(
                    FCNHead(nclass,
                            norm_layer=norm_layer,
                            norm_kwargs=norm_kwargs),
                    FCNHead(nclass,
                            norm_layer=norm_layer,
                            norm_kwargs=norm_kwargs))

    def base_forward(self, x):
        x = self.conv(x)
        x = self.max_pool(x)
        x = self.layer1(x)
        c2 = self.layer2(x)
        c3 = self.layer3(c2)
        return c2, c3

    def hybrid_forward(self, F, x, *args, **kwargs):
        c2, c3 = self.base_forward(x)
        outputs = []
        x = self.head(x, c2, c3)
        outputs.append(x)

        if self.aux:
            aux_outs = self.aux_head(c3, c2)
            outputs = outputs + aux_outs
        outputs = [
            F.contrib.BilinearResize2D(out, **self._up_kwargs)
            for out in outputs
        ]
        return tuple(outputs)
Example #14
 def __init__(self, nclass, aux, norm_layer=nn.BatchNorm, norm_kwargs=None):
     super(_MLAHead, self).__init__()
     self.aux = aux
     with self.name_scope():
         # top-down aggregation
         self.conv1x1_p5 = ConvModule2d(256,
                                        1,
                                        norm_layer=norm_layer,
                                        norm_kwargs=norm_kwargs)
         self.conv1x1_p4 = ConvModule2d(256,
                                        1,
                                        norm_layer=norm_layer,
                                        norm_kwargs=norm_kwargs)
         self.conv1x1_p3 = ConvModule2d(256,
                                        1,
                                        norm_layer=norm_layer,
                                        norm_kwargs=norm_kwargs)
         self.conv1x1_p2 = ConvModule2d(256,
                                        1,
                                        norm_layer=norm_layer,
                                        norm_kwargs=norm_kwargs)
         self.conv3x3_p5 = ConvModule2d(256,
                                        3,
                                        1,
                                        1,
                                        norm_layer=norm_layer,
                                        norm_kwargs=norm_kwargs)
         self.conv3x3_p4 = ConvModule2d(256,
                                        3,
                                        1,
                                        1,
                                        norm_layer=norm_layer,
                                        norm_kwargs=norm_kwargs)
         self.conv3x3_p3 = ConvModule2d(256,
                                        3,
                                        1,
                                        1,
                                        norm_layer=norm_layer,
                                        norm_kwargs=norm_kwargs)
         self.conv3x3_p2 = ConvModule2d(256,
                                        3,
                                        1,
                                        1,
                                        norm_layer=norm_layer,
                                        norm_kwargs=norm_kwargs)
         # segmentation head
         self.head5 = nn.HybridSequential()
         self.head5.add(
             ConvModule2d(128,
                          3,
                          1,
                          1,
                          norm_layer=norm_layer,
                          norm_kwargs=norm_kwargs),
             ConvModule2d(128,
                          3,
                          1,
                          1,
                          norm_layer=norm_layer,
                          norm_kwargs=norm_kwargs))
         self.head4 = nn.HybridSequential()
         self.head4.add(
             ConvModule2d(128,
                          3,
                          1,
                          1,
                          norm_layer=norm_layer,
                          norm_kwargs=norm_kwargs),
             ConvModule2d(128,
                          3,
                          1,
                          1,
                          norm_layer=norm_layer,
                          norm_kwargs=norm_kwargs))
         self.head3 = nn.HybridSequential()
         self.head3.add(
             ConvModule2d(128,
                          3,
                          1,
                          1,
                          norm_layer=norm_layer,
                          norm_kwargs=norm_kwargs),
             ConvModule2d(128,
                          3,
                          1,
                          1,
                          norm_layer=norm_layer,
                          norm_kwargs=norm_kwargs))
         self.head2 = nn.HybridSequential()
         self.head2.add(
             ConvModule2d(128,
                          3,
                          1,
                          1,
                          norm_layer=norm_layer,
                          norm_kwargs=norm_kwargs),
             ConvModule2d(128,
                          3,
                          1,
                          1,
                          norm_layer=norm_layer,
                          norm_kwargs=norm_kwargs))
         self.head = nn.Conv2D(nclass, 1, in_channels=128 * 4)
         if self.aux:
             self.aux_head = HybridConcurrentIsolate()
             self.aux_head.add(_SegHead(nclass, norm_layer, norm_kwargs),
                               _SegHead(nclass, norm_layer, norm_kwargs),
                               _SegHead(nclass, norm_layer, norm_kwargs),
                               _SegHead(nclass, norm_layer, norm_kwargs))
Example #15
class _MLAHead(nn.HybridBlock):
    def __init__(self, nclass, aux, norm_layer=nn.BatchNorm, norm_kwargs=None):
        super(_MLAHead, self).__init__()
        self.aux = aux
        with self.name_scope():
            # top-down aggregation
            self.conv1x1_p5 = ConvModule2d(256,
                                           1,
                                           norm_layer=norm_layer,
                                           norm_kwargs=norm_kwargs)
            self.conv1x1_p4 = ConvModule2d(256,
                                           1,
                                           norm_layer=norm_layer,
                                           norm_kwargs=norm_kwargs)
            self.conv1x1_p3 = ConvModule2d(256,
                                           1,
                                           norm_layer=norm_layer,
                                           norm_kwargs=norm_kwargs)
            self.conv1x1_p2 = ConvModule2d(256,
                                           1,
                                           norm_layer=norm_layer,
                                           norm_kwargs=norm_kwargs)
            self.conv3x3_p5 = ConvModule2d(256,
                                           3,
                                           1,
                                           1,
                                           norm_layer=norm_layer,
                                           norm_kwargs=norm_kwargs)
            self.conv3x3_p4 = ConvModule2d(256,
                                           3,
                                           1,
                                           1,
                                           norm_layer=norm_layer,
                                           norm_kwargs=norm_kwargs)
            self.conv3x3_p3 = ConvModule2d(256,
                                           3,
                                           1,
                                           1,
                                           norm_layer=norm_layer,
                                           norm_kwargs=norm_kwargs)
            self.conv3x3_p2 = ConvModule2d(256,
                                           3,
                                           1,
                                           1,
                                           norm_layer=norm_layer,
                                           norm_kwargs=norm_kwargs)
            # segmentation head
            self.head5 = nn.HybridSequential()
            self.head5.add(
                ConvModule2d(128,
                             3,
                             1,
                             1,
                             norm_layer=norm_layer,
                             norm_kwargs=norm_kwargs),
                ConvModule2d(128,
                             3,
                             1,
                             1,
                             norm_layer=norm_layer,
                             norm_kwargs=norm_kwargs))
            self.head4 = nn.HybridSequential()
            self.head4.add(
                ConvModule2d(128,
                             3,
                             1,
                             1,
                             norm_layer=norm_layer,
                             norm_kwargs=norm_kwargs),
                ConvModule2d(128,
                             3,
                             1,
                             1,
                             norm_layer=norm_layer,
                             norm_kwargs=norm_kwargs))
            self.head3 = nn.HybridSequential()
            self.head3.add(
                ConvModule2d(128,
                             3,
                             1,
                             1,
                             norm_layer=norm_layer,
                             norm_kwargs=norm_kwargs),
                ConvModule2d(128,
                             3,
                             1,
                             1,
                             norm_layer=norm_layer,
                             norm_kwargs=norm_kwargs))
            self.head2 = nn.HybridSequential()
            self.head2.add(
                ConvModule2d(128,
                             3,
                             1,
                             1,
                             norm_layer=norm_layer,
                             norm_kwargs=norm_kwargs),
                ConvModule2d(128,
                             3,
                             1,
                             1,
                             norm_layer=norm_layer,
                             norm_kwargs=norm_kwargs))
            self.head = nn.Conv2D(nclass, 1, in_channels=128 * 4)
            if self.aux:
                self.aux_head = HybridConcurrentIsolate()
                self.aux_head.add(_SegHead(nclass, norm_layer, norm_kwargs),
                                  _SegHead(nclass, norm_layer, norm_kwargs),
                                  _SegHead(nclass, norm_layer, norm_kwargs),
                                  _SegHead(nclass, norm_layer, norm_kwargs))

    def hybrid_forward(self, F, x, *args, **kwargs):
        c5 = self.conv1x1_p5(x)
        c4 = self.conv1x1_p4(args[0])
        c3 = self.conv1x1_p3(args[1])
        c2 = self.conv1x1_p2(args[2])

        c4_plus = c5 + c4
        c3_plus = c4_plus + c3
        c2_plus = c3_plus + c2

        p5 = self.head5(self.conv3x3_p5(c5))
        p4 = self.head4(self.conv3x3_p4(c4_plus))
        p3 = self.head3(self.conv3x3_p3(c3_plus))
        p2 = self.head2(self.conv3x3_p2(c2_plus))

        outputs = []
        out = self.head(F.concat(p5, p4, p3, p2, dim=1))
        outputs.append(out)

        if self.aux:
            aux_outs = self.aux_head(x, *args)
            outputs = outputs + aux_outs
        return tuple(outputs)
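Because the top-down aggregation above simply adds the 1x1-projected features, the four inputs to _MLAHead must share one spatial size, as a ViT-style SETR encoder provides. A dummy forward pass would be expected to behave like the sketch below; all shapes are illustrative (1024 channels mimic a ViT encoder), and _SegHead is omitted by passing aux=False.

import mxnet as mx

# Hypothetical shape check; assumes _MLAHead and ConvModule2d are importable as above.
head = _MLAHead(nclass=19, aux=False)
head.initialize()

feats = [mx.nd.random.uniform(shape=(1, 1024, 30, 30)) for _ in range(4)]
out, = head(feats[0], feats[1], feats[2], feats[3])
print(out.shape)   # expected (1, 19, 30, 30): class scores at the feature resolution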