Ejemplo n.º 1
0
    def forward(self, refine, x0, x1, x2, x3, coarse):
        """Training-time multi-stage point rendering over the coarse logits.

        In eval mode this defers to ``self.inference``. Otherwise five
        rendering stages run in order:

        * stages 1 and 2 sample points directly on ``coarse`` and pair it with
          ``x3`` / ``self.mlp3`` and ``x2`` / ``self.mlp2``;
        * stages 3, 4 and 5 each upsample the previous map by 2x (bilinear,
          ``align_corners=False``) and pair it with ``x1`` / ``self.mlp1``,
          ``x0`` / ``self.mlp0`` and ``refine`` / ``self.mlp_refine``.

        At each stage, points are chosen by ``sampling_points_v2`` on the
        sigmoid of the current map. Coarse and fine features are gathered at
        those points, concatenated on dim 1 and fed to the stage's MLP.
        (The original notes give the map sizes as 48x48, 48x48, 96x96,
        192x192 and 384x384; this depends on the caller's input size.)

        Returns:
            dict holding the untouched ``coarse`` input under ``"coarse"`` and
            a ``[points, rend]`` list under each of ``"stage1"``..``"stage5"``.
        """
        if not self.training:
            return self.inference(refine, x0, x1, x2, x3, coarse)

        def render(logits, fine_map, mlp, n_points):
            # Sample points on the sigmoid of the logits, then predict from the
            # concatenation of coarse (logits) and fine point features.
            pts = sampling_points_v2(torch.sigmoid(logits), N=n_points, k=3, beta=0.75)
            gathered = torch.cat(
                [
                    sampling_features(logits, pts, align_corners=False),
                    sampling_features(fine_map, pts, align_corners=False),
                ],
                dim=1,
            )
            return [pts, mlp(gathered)]

        def upsample(logits):
            return F.interpolate(logits, scale_factor=2, mode='bilinear', align_corners=False)

        # Stages 1-2: both operate on ``coarse`` at its own resolution.
        stage1 = render(coarse, x3, self.mlp3, 512)
        stage2 = render(coarse, x2, self.mlp2, 512)

        # Stages 3-5: each doubles the resolution of the previous map.
        up3 = upsample(coarse)
        stage3 = render(up3, x1, self.mlp1, 2048)
        up4 = upsample(up3)
        stage4 = render(up4, x0, self.mlp0, 2048)
        up5 = upsample(up4)
        stage5 = render(up5, refine, self.mlp_refine, 2048)

        return {
            "coarse": coarse,
            "stage1": stage1,
            "stage2": stage2,
            "stage3": stage3,
            "stage4": stage4,
            "stage5": stage5,
        }
Ejemplo n.º 2
0
    def inference(self, x, feature, mask, num_points=512):
        """Upsample ``mask`` to the width of ``x``, refining sampled points.

        Each iteration does the following:

        1. Doubles ``mask`` (bilinear, ``align_corners=False``).
        2. Selects ``num_points`` points with ``sampling_points_v2`` on the
           sigmoid of the upsampled mask.
        3. Predicts new values for those points with ``self.mlp`` from the
           concatenated mask and ``feature`` point features.
        4. Scatters the predictions back into the flattened spatial dimension
           of the mask.

        Args:
            x: only ``x.shape[-1]`` is read; it is the target width.
            feature: map sampled at the selected points via
                ``sampling_features`` (presumably the fine feature map).
            mask: coarse logits. ``x.shape[-1]`` must equal
                ``mask.shape[-1] * 2**k`` for some ``k >= 0``.
            num_points: points refined per doubling step. The default of 512
                is the previously hard-coded value.

        Returns:
            ``{"fine": mask}``, the refined mask at the target width. If the
            widths already match, the input ``mask`` is returned unchanged.

        Raises:
            ValueError: if repeated doubling of the mask width cannot land
                exactly on the target width. The previous ``!=`` loop spun
                forever in that case.
        """
        src_w = mask.shape[-1]
        dst_w = x.shape[-1]
        ratio = dst_w // src_w if src_w > 0 else 0
        # ``ratio & (ratio - 1)`` is non-zero exactly when ratio is not a power of two.
        if ratio < 1 or dst_w % src_w != 0 or ratio & (ratio - 1):
            raise ValueError(
                f"mask width {src_w} cannot be doubled to reach target width {dst_w}"
            )

        while mask.shape[-1] < dst_w:
            mask = F.interpolate(mask,
                                 scale_factor=2,
                                 mode="bilinear",
                                 align_corners=False)

            points_idx, points = sampling_points_v2(torch.sigmoid(mask),
                                                    num_points,
                                                    training=self.training)

            coarse = sampling_features(mask, points, align_corners=False)
            fine = sampling_features(feature, points, align_corners=False)
            rend = self.mlp(torch.cat([coarse, fine], dim=1))

            B, C, H, W = mask.shape
            # points_idx indexes the flattened H*W axis; broadcast it over channels.
            points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
            mask = (mask.reshape(B, C, -1).scatter_(2, points_idx,
                                                    rend).view(B, C, H, W))

        return {"fine": mask}
Ejemplo n.º 3
0
    def forward(self, x, feature, mask):
        """Single-stage point rendering; uses ``self.inference`` in eval mode.

        During training, 2048 points are chosen by ``sampling_points_v2`` on
        the sigmoid of ``mask`` (with ``self.k`` / ``self.beta``). The mask and
        ``feature`` are sampled at those points, concatenated on dim 1 and
        passed through ``self.mlp``.

        Returns:
            In training, ``{"rend", "points", "coarse"}``, where ``"coarse"``
            is the input ``mask``. In eval mode, whatever ``self.inference``
            returns.
        """
        if self.training:
            pts = sampling_points_v2(torch.sigmoid(mask), 2048, self.k,
                                     self.beta)
            mask_at_pts = sampling_features(mask, pts, align_corners=False)
            feat_at_pts = sampling_features(feature, pts, align_corners=False)
            prediction = self.mlp(torch.cat([mask_at_pts, feat_at_pts], dim=1))
            return {"rend": prediction, "points": pts, "coarse": mask}

        return self.inference(x, feature, mask)
Ejemplo n.º 4
0
    def inference(self, refine, x0, x1, x2, x3, coarse, num_points=512):
        """Eval-time cascaded refinement of ``coarse`` through three 2x stages.

        Each stage does the following:

        1. Upsamples the current map by 2x (bilinear, ``align_corners=True``).
        2. Selects ``num_points`` points with ``sampling_points_v2`` on the
           channel softmax.
        3. Predicts new values for those points from concatenated coarse/fine
           point features.
        4. Scatters the predictions back into the upsampled map.

        The stages pair ``x1`` with ``self.mlp1``, ``x0`` with ``self.mlp0``,
        and ``refine`` with ``self.mlp_refine``, in that order. The original
        notes put the resulting sizes at 96x96, 192x192 and 384x384.

        ``x2`` and ``x3`` are accepted to mirror ``forward``'s signature but
        are unused. The stages that consumed them (``self.mlp2`` /
        ``self.mlp3`` at the coarse resolution) are disabled.

        Args:
            num_points: points refined per stage. The default of 512 is the
                previously hard-coded value.

        Returns:
            ``{"fine": refined}``, the map after the last stage.
        """
        stages = (
            (x1, self.mlp1),
            (x0, self.mlp0),
            (refine, self.mlp_refine),
        )
        refined = coarse
        for fine_map, mlp in stages:
            upsampled = F.interpolate(refined,
                                      scale_factor=2,
                                      mode='bilinear',
                                      align_corners=True)
            points_idx, points = sampling_points_v2(torch.softmax(upsampled, dim=1),
                                                    num_points,
                                                    training=self.training)
            coarse_feature = sampling_features(upsampled, points, align_corners=True)
            fine_feature = sampling_features(fine_map, points, align_corners=True)
            rend = mlp(torch.cat([coarse_feature, fine_feature], dim=1))

            B, C, H, W = upsampled.shape
            # points_idx indexes the flattened H*W axis; broadcast it over channels.
            index = points_idx.unsqueeze(1).expand(-1, C, -1)
            refined = (upsampled.reshape(B, C, -1)
                       .scatter_(2, index, rend)
                       .view(B, C, H, W))

        return {"fine": refined}
Ejemplo n.º 5
0
    def forward(self, refine, x0, x1, x2, x3, coarse, num_points=2048):
        """Training-time three-stage point rendering; eval mode uses ``inference``.

        Starting from ``coarse``, each stage does the following:

        1. Upsamples the previous map by 2x (bilinear, ``align_corners=True``).
        2. Samples ``num_points`` points with ``sampling_points_v2`` on the
           channel softmax (``k=3``, ``beta=0.75``).
        3. Gathers coarse and fine features at those points, concatenates them
           on dim 1 and applies the stage's MLP.

        Stages 3, 4 and 5 use ``x1`` / ``self.mlp1``, ``x0`` / ``self.mlp0``
        and ``refine`` / ``self.mlp_refine``. The original notes give their
        sizes as 96x96, 192x192 and 384x384.

        ``x2`` and ``x3`` are accepted for signature compatibility but unused
        in training. The stages that consumed them ("stage1"/"stage2") are
        disabled.

        Args:
            num_points: points sampled per stage. The default of 2048 is the
                previously hard-coded value.

        Returns:
            dict with the input ``coarse`` under ``"coarse"`` and a
            ``[points, rend]`` list under ``"stage3"``, ``"stage4"`` and
            ``"stage5"``.
        """
        if not self.training:
            return self.inference(refine, x0, x1, x2, x3, coarse)

        stages = (
            ("stage3", x1, self.mlp1),
            ("stage4", x0, self.mlp0),
            ("stage5", refine, self.mlp_refine),
        )
        outputs = {"coarse": coarse}
        logits = coarse
        for key, fine_map, mlp in stages:
            logits = F.interpolate(logits,
                                   scale_factor=2,
                                   mode='bilinear',
                                   align_corners=True)
            points = sampling_points_v2(torch.softmax(logits, dim=1),
                                        N=num_points,
                                        k=3,
                                        beta=0.75)
            coarse_feature = sampling_features(logits, points, align_corners=True)
            fine_feature = sampling_features(fine_map, points, align_corners=True)
            rend = mlp(torch.cat([coarse_feature, fine_feature], dim=1))
            outputs[key] = [points, rend]

        return outputs