示例#1
0
    def inference(self, x, feature, mask, num_points=512):
        """Iteratively upsample ``mask`` and re-predict selected points.

        Each round doubles the spatial size of ``mask`` with bilinear
        interpolation, selects ``num_points`` locations with
        ``sampling_points_v2`` on ``sigmoid(mask)``, builds per-point features
        from the coarse logits and ``feature``, runs ``self.mlp`` on them and
        scatters the predictions back into ``mask``. Rounds stop as soon as
        the last dimension of ``mask`` equals the last dimension of ``x``
        (zero rounds if they already match).

        Args:
            x: Tensor whose last dimension is the target width (presumably
                the input image — confirm against the caller).
            feature: Fine feature map sampled at the selected points.
            mask: 4-D coarse logits ``(B, C, H, W)``; sigmoid is applied
                before point selection.
            num_points: Number of points re-predicted per round.

        Returns:
            dict ``{"fine": refined_mask}``.

        Raises:
            ValueError: if repeated doubling of ``mask`` cannot reach the
                width of ``x`` (mask already wider, or the width ratio is not
                a power of two), rather than looping forever.
        """
        target_width = x.shape[-1]
        while mask.shape[-1] != target_width:
            # Doubling only grows the width; once it is past the target,
            # equality can never be reached.
            if mask.shape[-1] > target_width:
                raise ValueError(
                    f"cannot upsample mask of width {mask.shape[-1]} to "
                    f"width {target_width} by repeated doubling")

            mask = F.interpolate(mask,
                                 scale_factor=2,
                                 mode="bilinear",
                                 align_corners=False)

            points_idx, points = sampling_points_v2(torch.sigmoid(mask),
                                                    num_points,
                                                    training=self.training)

            coarse = sampling_features(mask, points, align_corners=False)
            fine = sampling_features(feature, points, align_corners=False)
            feature_representation = torch.cat([coarse, fine], dim=1)
            rend = self.mlp(feature_representation)

            # Overwrite the selected flattened spatial positions of every
            # channel with the MLP output. Assumes points_idx holds flat
            # indices into H*W, shape (B, num_points) — TODO confirm against
            # sampling_points_v2.
            B, C, H, W = mask.shape
            points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
            mask = (mask.reshape(B, C, -1)
                    .scatter_(2, points_idx, rend)
                    .view(B, C, H, W))

        return {"fine": mask}
示例#2
0
    def forward(self, output, mask):
        """Cross-entropy on the coarse map plus point cross-entropy for stages 3-5.

        ``output`` is unpacked positionally into exactly four values: the
        coarse logits followed by the stage 3, 4 and 5 entries. Each stage
        entry is indexed as ``stage[0]`` (point coordinates) and ``stage[1]``
        (point logits). Class labels come from ``mask.argmax(dim=1)``, both
        for the full map and for the nearest-neighbour samples at the points.
        Stages 1 and 2 are not part of this loss.
        """
        coarse, stage3, stage4, stage5 = output.values()

        upsampled = F.interpolate(coarse, mask.shape[-2:], mode="bilinear", align_corners=True)

        point_loss = 0
        for stage in (stage3, stage4, stage5):
            point_labels = sampling_features(mask, stage[0], mode='nearest',
                                             align_corners=True).argmax(dim=1)
            point_loss = point_loss + F.cross_entropy(stage[1], point_labels)

        seg_loss = F.cross_entropy(upsampled, mask.argmax(dim=1))
        return point_loss + seg_loss
示例#3
0
    def forward(self, refine, x0, x1, x2, x3, coarse):
        """Training-time point rendering over five stages.

        Outside training mode this simply returns ``self.inference(...)``.
        In training mode every stage picks points with
        ``sampling_points_v2(sigmoid(logits), N, k=3, beta=0.75)``,
        concatenates the logits and a fine feature map sampled at those
        points, and runs the stage's MLP:

        - stage1: logits = coarse,            fine = x3,     head = mlp3,       N = 512
        - stage2: logits = coarse,            fine = x2,     head = mlp2,       N = 512
        - stage3: logits = coarse up x2,      fine = x1,     head = mlp1,       N = 2048
        - stage4: logits = coarse up x4,      fine = x0,     head = mlp0,       N = 2048
        - stage5: logits = coarse up x8,      fine = refine, head = mlp_refine, N = 2048

        Returns:
            dict with ``"coarse"`` and ``"stage1"`` .. ``"stage5"``; each
            stage maps to ``[points, rend]``.
        """
        if not self.training:
            return self.inference(refine, x0, x1, x2, x3, coarse)

        def render(logits, fine_map, head, n_points):
            # One point-rendering step at the resolution of `logits`.
            pts = sampling_points_v2(torch.sigmoid(logits), N=n_points, k=3, beta=0.75)
            per_point = torch.cat([
                sampling_features(logits, pts, align_corners=False),
                sampling_features(fine_map, pts, align_corners=False),
            ], dim=1)
            return [pts, head(per_point)]

        def upsample(t):
            return F.interpolate(t, scale_factor=2, mode='bilinear', align_corners=False)

        stage1 = render(coarse, x3, self.mlp3, 512)
        stage2 = render(coarse, x2, self.mlp2, 512)
        up2 = upsample(coarse)
        stage3 = render(up2, x1, self.mlp1, 2048)
        up4 = upsample(up2)
        stage4 = render(up4, x0, self.mlp0, 2048)
        up8 = upsample(up4)
        stage5 = render(up8, refine, self.mlp_refine, 2048)

        return {
            "coarse": coarse,
            "stage1": stage1,
            "stage2": stage2,
            "stage3": stage3,
            "stage4": stage4,
            "stage5": stage5,
        }
示例#4
0
    def forward(self, x, feature, mask):
        """Training step of the point head; eval mode defers to ``self.inference``.

        Samples 2048 points from ``sigmoid(mask)`` with ``sampling_points_v2``
        (parameterised by ``self.k`` and ``self.beta``), gathers the coarse
        logits and the fine ``feature`` at those points, and predicts them
        with ``self.mlp``.

        Returns:
            dict with ``"rend"`` (MLP output), ``"points"`` (sampled point
            coordinates) and ``"coarse"`` (the input ``mask``).
        """
        if self.training:
            points = sampling_points_v2(torch.sigmoid(mask), 2048, self.k, self.beta)
            per_point = torch.cat(
                (sampling_features(mask, points, align_corners=False),
                 sampling_features(feature, points, align_corners=False)),
                dim=1)
            return {"rend": self.mlp(per_point), "points": points, "coarse": mask}
        return self.inference(x, feature, mask)
示例#5
0
    def forward(self, output, mask):
        """Coarse cross-entropy plus point cross-entropy.

        ``output['coarse']`` is upsampled to ``mask.shape[-2:]`` and scored
        against ``mask.argmax(dim=1)``. Point labels are the argmax of
        ``mask`` bilinearly sampled at ``output['points']``, scored against
        ``output['rend']``.

        NOTE(review): both the upsampling and the label sampling here use
        ``align_corners=True``; the point head producing ``output`` may
        sample its features with ``align_corners=False`` — confirm the
        coordinate conventions agree.
        """
        upsampled = F.interpolate(output['coarse'], mask.shape[-2:],
                                  mode="bilinear", align_corners=True)
        sampled = sampling_features(mask, output['points'],
                                    mode='bilinear', align_corners=True)
        point_labels = sampled.argmax(dim=1)
        dense_labels = mask.argmax(dim=1)

        seg_term = F.cross_entropy(upsampled, dense_labels)
        point_term = F.cross_entropy(output['rend'], point_labels)
        return seg_term + point_term
示例#6
0
    def forward(self, output, mask):
        """BCE-with-logits on the coarse map plus point losses for stages 3-5.

        ``output`` is unpacked positionally into six values: the coarse
        logits followed by the stage 1..5 entries, each indexed as
        ``stage[0]`` (point coordinates) and ``stage[1]`` (point logits).
        Only stages 3-5 contribute to the point loss. Stages 1 and 2 are
        unpacked but not evaluated; the previous version computed their
        losses and then discarded them, which wasted sampling/BCE work.

        ``mask`` is the BCE target for the upsampled coarse logits, so it
        must have the same shape as them (presumably binary in [0, 1] —
        confirm against the data pipeline).
        """
        coarse, _stage1, _stage2, stage3, stage4, stage5 = output.values()

        pred0 = F.interpolate(coarse, mask.shape[-2:], mode="bilinear", align_corners=False)
        seg_loss = F.binary_cross_entropy_with_logits(pred0, mask)

        point_loss = 0
        for stage in (stage3, stage4, stage5):
            gt_points = sampling_features(mask, stage[0], mode='nearest')
            point_loss = point_loss + F.binary_cross_entropy_with_logits(stage[1], gt_points)

        return seg_loss + point_loss
示例#7
0
    def forward(self, output, mask):
        """Soft Dice loss on the coarse map plus BCE on the point predictions.

        Args:
            output: dict with ``'coarse'`` logits, ``'points'`` sampled point
                coordinates and ``'rend'`` point logits.
            mask: target with the same shape as the coarse logits upsampled
                to ``mask.shape[-2:]``; presumably binary in [0, 1] —
                confirm against the data pipeline.

        Returns:
            Scalar ``dice_loss + point_bce``.
        """
        # F.upsample is deprecated; F.interpolate is its direct replacement.
        pred = torch.sigmoid(
            F.interpolate(output['coarse'],
                          mask.shape[-2:],
                          mode="bilinear",
                          align_corners=True))
        gt_points = sampling_features(mask, output['points'], mode='nearest')

        # Per-sample soft Dice with additive smoothing. The coefficient is
        # (2*I + s) / (P + T + s), which lies in (0, 1]. The previous form
        # 2*(I + s) / (P + T + s) exceeded 1, e.g. 2.0 for empty pred and
        # target, which made the loss negative.
        n = mask.size(0)
        smooth = 1
        input_flat = pred.reshape(n, -1)
        target_flat = mask.reshape(n, -1)  # reshape: mask may be non-contiguous
        intersection = (input_flat * target_flat).sum(1)
        dice = (2 * intersection + smooth) / (
            input_flat.sum(1) + target_flat.sum(1) + smooth)
        seg_loss = 1 - dice.mean()

        # Use logits directly: numerically stable, unlike BCE(sigmoid(x)).
        point_loss = F.binary_cross_entropy_with_logits(output['rend'], gt_points)

        return seg_loss + point_loss
示例#8
0
    def forward(self, output, mask):
        """BCE-with-logits on the upsampled coarse map plus on the point logits.

        Args:
            output: dict with ``'coarse'`` logits, ``'points'`` point
                coordinates and ``'rend'`` point logits.
            mask: BCE target for the coarse logits upsampled (bilinear,
                ``align_corners=False``) to ``mask.shape[-2:]``. It is also
                sampled with nearest-neighbour at the points to form the
                point targets.

        Returns:
            Scalar ``seg_bce + point_bce``.
        """
        coarse_up = F.interpolate(output['coarse'], mask.shape[-2:],
                                  mode="bilinear", align_corners=False)
        point_targets = sampling_features(mask, output['points'], mode='nearest')

        terms = (
            F.binary_cross_entropy_with_logits(coarse_up, mask),
            F.binary_cross_entropy_with_logits(output['rend'], point_targets),
        )
        return terms[0] + terms[1]
示例#9
0
    def inference(self, refine, x0, x1, x2, x3, coarse):
        """Eval-time refinement: three upsample-and-refine rounds.

        Starting from ``coarse``, each round does the following:

        1. Doubles the logits with bilinear interpolation
           (``align_corners=True``).
        2. Picks 512 points with ``sampling_points_v2`` on
           ``softmax(logits, dim=1)``.
        3. Concatenates the logits and a fine feature map sampled at those
           points.
        4. Runs the round's MLP and scatters its output back into the
           logits at the returned flat indices.

        The rounds use, in order, ``(x1, mlp1)``, ``(x0, mlp0)`` and
        ``(refine, mlp_refine)``. ``x2`` and ``x3`` are accepted to mirror
        ``forward``'s signature but are not used here.

        Returns:
            dict ``{"fine": refined_logits}``.
        """
        logits = coarse
        rounds = ((x1, "mlp1"), (x0, "mlp0"), (refine, "mlp_refine"))
        for fine_map, head_name in rounds:
            logits = F.interpolate(logits,
                                   scale_factor=2,
                                   mode='bilinear',
                                   align_corners=True)
            flat_idx, pts = sampling_points_v2(torch.softmax(logits, dim=1),
                                               512,
                                               training=self.training)
            per_point = torch.cat(
                [sampling_features(logits, pts, align_corners=True),
                 sampling_features(fine_map, pts, align_corners=True)],
                dim=1)
            predicted = getattr(self, head_name)(per_point)

            batch, channels, height, width = logits.shape
            flat_idx = flat_idx.unsqueeze(1).expand(-1, channels, -1)
            flat_logits = logits.reshape(batch, channels, -1)
            logits = flat_logits.scatter_(2, flat_idx, predicted).view(
                batch, channels, height, width)

        return {"fine": logits}
示例#10
0
    def forward(self, refine, x0, x1, x2, x3, coarse):
        """Training-time point rendering for stages 3-5.

        Outside training mode this returns ``self.inference(...)``. In
        training mode ``coarse`` is repeatedly doubled (bilinear,
        ``align_corners=True``). At each resolution the stage does three
        things:

        1. Samples 2048 points with
           ``sampling_points_v2(softmax(logits, dim=1), N=2048, k=3, beta=0.75)``.
        2. Concatenates the logits and a fine feature map sampled at those
           points.
        3. Runs the stage's MLP.

        The stages use ``(x1, mlp1)``, ``(x0, mlp0)`` and
        ``(refine, mlp_refine)``; ``x2`` and ``x3`` are not used.

        Returns:
            dict with ``"coarse"`` and ``"stage3"`` .. ``"stage5"``; each
            stage maps to ``[points, rend]``.
        """
        if not self.training:
            return self.inference(refine, x0, x1, x2, x3, coarse)

        outputs = {"coarse": coarse}
        logits = coarse
        stages = (("stage3", x1, "mlp1"),
                  ("stage4", x0, "mlp0"),
                  ("stage5", refine, "mlp_refine"))
        for key, fine_map, head_name in stages:
            logits = F.interpolate(logits,
                                   scale_factor=2,
                                   mode='bilinear',
                                   align_corners=True)
            pts = sampling_points_v2(torch.softmax(logits, dim=1),
                                     N=2048,
                                     k=3,
                                     beta=0.75)
            per_point = torch.cat(
                [sampling_features(logits, pts, align_corners=True),
                 sampling_features(fine_map, pts, align_corners=True)],
                dim=1)
            outputs[key] = [pts, getattr(self, head_name)(per_point)]

        return outputs
示例#11
0
        self.rend = RendNet(n_class=n_class)

    def forward(self, x):
        """Run ``self.seg`` on ``x`` and pass its six outputs to ``self.rend``."""
        # self.seg must yield exactly six items, forwarded in the same order.
        refine, x0, x1, x2, x3, coarse = self.seg(x)
        return self.rend(refine, x0, x1, x2, x3, coarse)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


if __name__ == "__main__":
    # Smoke test: one training-mode forward pass on random data, printing the
    # shape of every output entry plus the targets sampled at each stage's points.
    model = RendUNet()
    img = torch.rand(4, 3, 512, 512)
    mask = torch.rand(4, 3, 512, 512)
    model.train()
    n_params = sum(p.numel() for p in model.parameters())
    print('# parameters:', n_params)
    res = model(img)
    for name, value in res.items():
        if name != "coarse":
            points, rend = value[0], value[1]
            print(name, points.shape, rend.shape,
                  sampling_features(mask, points).shape)
        else:
            print(name, value.shape)
示例#12
0

class RendDANet(BaseNet):
    """DANet head producing coarse logits, refined by a point head.

    ``forward`` returns whatever ``PointHead`` returns for
    ``(x, c2, coarse_logits)``.
    """

    # Channel count of the backbone's second-stage output (c2) given to the
    # point head. It is 512 for bottleneck ResNets such as resnet50/101.
    # NOTE(review): confirm for other backbones.
    _C2_CHANNELS = 512

    def __init__(self, nclass, backbone, norm_layer=nn.BatchNorm2d):
        super(RendDANet, self).__init__(nclass,
                                        backbone,
                                        norm_layer=norm_layer)
        self.head = DANetHead(2048, 512, norm_layer=norm_layer)
        self.seg1 = nn.Sequential(nn.Dropout(0.1), nn.Conv2d(512, nclass, 1))
        # Assumes the point head concatenates the nclass coarse logits with
        # the c2 features, so its input width must track nclass. It was
        # hard-coded to 527 (= 512 + 15), which only worked for nclass=15.
        self.rend_head = PointHead(in_c=self._C2_CHANNELS + nclass,
                                   num_classes=nclass)

    def forward(self, x):
        _, c2, _, c4 = self.base_forward(x)

        coarse_logits = self.seg1(self.head(c4))
        return self.rend_head(x, c2, coarse_logits)


if __name__ == "__main__":
    # Smoke test: one training-mode forward pass, then nearest-neighbour
    # sampling of a random target at the returned points.
    net = RendDANet(backbone='resnet101', nclass=15)
    img = torch.rand(4, 3, 384, 384)
    mask = torch.rand(4, 8, 384, 384)
    net.train()
    output = net(img)
    for name, tensor in output.items():
        print(name, tensor.shape)
    sampled = sampling_features(mask, output['points'], mode='nearest')
    print(sampled.shape)