Example #1
0
    def forward(self, x):
        """Run the ICNet-style cascade (ResNet-backbone variant).

        Three branches are built and fused coarse-to-fine:

        * ``x_sub2``: the input resized with ``s_factor=2``, then passed through
          ``layer0``, ``layer1`` and ``res_block3_conv``.
        * ``x_sub4``: ``x_sub2`` resized again with ``s_factor=2``, then passed
          through ``res_block3_identity``, ``layer3``, ``layer4``,
          ``pyramid_pooling`` and ``conv5_4_k1``.
        * ``x_sub1``: the ``convbnrelu*_sub1`` stack applied to the original
          input ``x``.

        Args:
            x: input batch. It must be 4-D (bilinear ``F.interpolate`` requires
                it); presumably (N, C, H, W) images -- TODO confirm.

        Returns:
            In training mode, the tuple ``(sub124_cls, sub24_cls, sub4_cls)``
            produced by the classifier and the two cascade-fusion stages
            (presumably used for deep supervision -- confirm against the loss).
            In eval mode, only ``sub124_cls``, bilinearly resized with
            ``z_factor=4``.
        """
        # H, W -> H/2, W/2
        x_sub2 = F.interpolate(
            x, size=get_interp_size(x, s_factor=2), mode="bilinear", align_corners=True
        )

        # H/2, W/2 -> H/8, W/8: the ResNet stem (conv + maxpool) stands in for
        # the original convbnrelu1_1..1_3 stack followed by F.max_pool2d.
        x_sub2 = self.layer0(x_sub2)

        # H/8, W/8 -> H/16, W/16: layer1 stands in for the original res_block2.
        x_sub2 = self.layer1(x_sub2)
        x_sub2 = self.res_block3_conv(x_sub2)

        # H/16, W/16 -> H/32, W/32
        x_sub4 = F.interpolate(
            x_sub2, size=get_interp_size(x_sub2, s_factor=2), mode="bilinear", align_corners=True
        )
        x_sub4 = self.res_block3_identity(x_sub4)

        # layer3 / layer4 stand in for the original res_block4 / res_block5.
        x_sub4 = self.layer3(x_sub4)
        x_sub4 = self.layer4(x_sub4)

        x_sub4 = self.pyramid_pooling(x_sub4)
        x_sub4 = self.conv5_4_k1(x_sub4)

        # Branch on the original-resolution input.
        x_sub1 = self.convbnrelu1_sub1(x)
        x_sub1 = self.convbnrelu2_sub1(x_sub1)
        x_sub1 = self.convbnrelu3_sub1(x_sub1)

        # Coarse-to-fine cascade feature fusion; each stage also returns an
        # auxiliary output (sub4_cls / sub24_cls) used only in training mode.
        x_sub24, sub4_cls = self.cff_sub24(x_sub4, x_sub2)
        x_sub12, sub24_cls = self.cff_sub12(x_sub24, x_sub1)

        x_sub12 = F.interpolate(
            x_sub12, size=get_interp_size(x_sub12, z_factor=2), mode="bilinear", align_corners=True
        )
        # No further work on x_sub4 here: its fused contribution is already in
        # x_sub24, and any extra module call would be discarded (and, for
        # BatchNorm layers in training mode, would still update running stats).
        sub124_cls = self.classification(x_sub12)

        if self.training:
            return (sub124_cls, sub24_cls, sub4_cls)

        # Eval: resize the final output with z_factor=4.
        return F.interpolate(
            sub124_cls,
            size=get_interp_size(sub124_cls, z_factor=4),
            mode="bilinear",
            align_corners=True,
        )
Example #2
0
    def forward(self, x):
        """Run the ICNet cascade with a parallel instance-embedding path.

        The semantic path is the usual three-branch cascade:

        * ``x_sub2``: resized input (``s_factor=2``) through the
          ``convbnrelu1_*`` stem, max-pooling, ``res_block2`` and
          ``res_block3_conv``.
        * ``x_sub4``: ``x_sub2`` resized again (``s_factor=2``) through
          ``res_block3_identity``, ``res_block4``, ``res_block5``,
          ``pyramid_pooling`` and ``conv5_4_k1``.
        * ``x_sub1``: the ``convbnrelu*_sub1`` stack on the original input.

        The instance path shares the activations one step before each
        branch's final semantic layer (the outputs of ``res_block2``,
        ``pyramid_pooling`` and ``convbnrelu2_sub1``). It then applies its own
        ``*_inst`` layers and cascade fusion, ending in
        ``inst_feature_embedding``.

        Args:
            x: input batch. It must be 4-D (bilinear ``F.interpolate`` requires
                it); presumably (N, C, H, W) images -- TODO confirm.

        Returns:
            In training mode, the 6-tuple ``(sub124_cls, sub24_cls, sub4_cls,
            sub124_inst_fe, sub24_inst_fe, sub4_inst_fe)``.
            In eval mode, ``(sub124_cls, sub124_inst_fe)``, each bilinearly
            resized with ``z_factor=4``.
        """
        # H, W -> H/2, W/2
        x_sub2 = F.interpolate(
            x, size=get_interp_size(x, s_factor=2), mode="bilinear", align_corners=True
        )

        # H/2, W/2 -> H/4, W/4
        x_sub2 = self.convbnrelu1_1(x_sub2)
        x_sub2 = self.convbnrelu1_2(x_sub2)
        x_sub2 = self.convbnrelu1_3(x_sub2)

        # H/4, W/4 -> H/8, W/8
        x_sub2 = F.max_pool2d(x_sub2, 3, 2, 1)

        # H/8, W/8 -> H/16, W/16
        # The res_block2 output feeds both the semantic and the instance branch.
        x_sub2_shared = self.res_block2(x_sub2)
        x_sub2 = self.res_block3_conv(x_sub2_shared)
        x_sub2_inst = self.res_block3_conv_inst(x_sub2_shared)

        # H/16, W/16 -> H/32, W/32
        x_sub4 = F.interpolate(
            x_sub2, size=get_interp_size(x_sub2, s_factor=2), mode="bilinear", align_corners=True
        )
        x_sub4 = self.res_block3_identity(x_sub4)
        x_sub4 = self.res_block4(x_sub4)
        x_sub4 = self.res_block5(x_sub4)

        # Pyramid-pooled features feed both the semantic and the instance head.
        x_sub4_shared = self.pyramid_pooling(x_sub4)
        x_sub4 = self.conv5_4_k1(x_sub4_shared)
        x_sub4_inst = self.conv5_4_k1_inst(x_sub4_shared)

        # Branch on the original-resolution input; the second conv's output
        # feeds both the semantic and the instance head.
        x_sub1 = self.convbnrelu1_sub1(x)
        x_sub1_shared = self.convbnrelu2_sub1(x_sub1)
        x_sub1 = self.convbnrelu3_sub1(x_sub1_shared)
        x_sub1_inst = self.convbnrelu3_sub1_inst(x_sub1_shared)

        # Coarse-to-fine cascade feature fusion for each path; every stage also
        # returns an auxiliary output used only in training mode.
        x_sub24, sub4_cls = self.cff_sub24(x_sub4, x_sub2)
        x_sub12, sub24_cls = self.cff_sub12(x_sub24, x_sub1)

        x_sub24_inst, sub4_inst_fe = self.cff_sub24_inst(x_sub4_inst, x_sub2_inst)
        x_sub12_inst, sub24_inst_fe = self.cff_sub12_inst(x_sub24_inst, x_sub1_inst)

        x_sub12 = F.interpolate(
            x_sub12, size=get_interp_size(x_sub12, z_factor=2), mode="bilinear", align_corners=True
        )
        x_sub12_inst = F.interpolate(
            x_sub12_inst,
            size=get_interp_size(x_sub12_inst, z_factor=2),
            mode="bilinear",
            align_corners=True,
        )
        sub124_cls = self.classification(x_sub12)
        sub124_inst_fe = self.inst_feature_embedding(x_sub12_inst)

        if self.training:
            return (sub124_cls, sub24_cls, sub4_cls, sub124_inst_fe, sub24_inst_fe, sub4_inst_fe)

        # Eval: resize both final outputs with z_factor=4.
        sub124_cls = F.interpolate(
            sub124_cls,
            size=get_interp_size(sub124_cls, z_factor=4),
            mode="bilinear",
            align_corners=True,
        )
        sub124_inst_fe = F.interpolate(
            sub124_inst_fe,
            size=get_interp_size(sub124_inst_fe, z_factor=4),
            mode="bilinear",
            align_corners=True,
        )
        return sub124_cls, sub124_inst_fe