コード例 #1
0
ファイル: foldingnet.py プロジェクト: mahaling/FoldingNet
    def __init__(self,
                 MLP_dims,
                 FC_dims,
                 grid_dims,
                 Folding1_dims,
                 Folding2_dims,
                 MLP_doLastRelu=False):
        """Set up the PointNet encoder, the fixed 2D grid and both folding stages.

        Args:
            MLP_dims: dimension spec forwarded to ``PointNetVanilla``; its last
                entry must equal ``FC_dims[0]``.
            FC_dims: dimension spec forwarded to ``PointNetVanilla``.
            grid_dims: indexable pair; the grid holds
                ``grid_dims[0] * grid_dims[1]`` points.
            Folding1_dims: dimension spec forwarded to the first
                ``FoldingNetSingle``.
            Folding2_dims: dimension spec forwarded to the second
                ``FoldingNetSingle``.
            MLP_doLastRelu: flag forwarded unchanged to ``PointNetVanilla``.
        """
        assert (MLP_dims[-1] == FC_dims[0])
        super(FoldingNetVanilla, self).__init__()
        # Encoder
        #   PointNet
        self.PointNet = PointNetVanilla(MLP_dims, FC_dims, MLP_doLastRelu)

        # Decoder
        #   Folding
        #     2D grid: (grid_dims[0] * grid_dims[1]) x 2
        # TODO: normalize the grid to align with the input data
        n_u, n_v = grid_dims[0], grid_dims[1]
        self.N = n_u * n_v
        # Build the axes as float32 explicitly: dividing an integer arange by
        # an int is floor division on older PyTorch releases, which would
        # collapse every coordinate to -0.5.
        u_axis = torch.arange(0, n_u, dtype=torch.float32) / n_u - 0.5
        v_axis = torch.arange(0, n_v, dtype=torch.float32) / n_v - 0.5
        # u cycles fastest; each v value is held for n_u consecutive rows.
        u = u_axis.repeat(n_v)
        v = v_axis.expand(n_u, -1).t().reshape(-1)
        self.grid = torch.stack((u, v), 1)  # Nx2

        #     1st folding
        self.Fold1 = FoldingNetSingle(Folding1_dims)
        #     2nd folding
        self.Fold2 = FoldingNetSingle(Folding2_dims)
コード例 #2
0
ファイル: foldingnet.py プロジェクト: mahaling/FoldingNet
    def __init__(self,
                 MLP_dims,
                 FC_dims,
                 Folding1_dims,
                 Folding2_dims,
                 MLP_doLastRelu=False):
        """Set up the PointNet encoder, the shape template and both folding stages.

        Args:
            MLP_dims: dimension spec forwarded to ``PointNetVanilla``; its last
                entry must equal ``FC_dims[0]``.
            FC_dims: dimension spec forwarded to ``PointNetVanilla``.
            Folding1_dims: dimension spec forwarded to the first
                ``FoldingNetSingle``.
            Folding2_dims: dimension spec forwarded to the second
                ``FoldingNetSingle``.
            MLP_doLastRelu: flag forwarded unchanged to ``PointNetVanilla``.
        """
        assert (MLP_dims[-1] == FC_dims[0])
        super(FoldingNetShapes, self).__init__()
        # Encoder
        #   PointNet
        self.PointNet = PointNetVanilla(MLP_dims, FC_dims, MLP_doLastRelu)

        # Decoder
        #   Folding
        self.box = make_box()  # 18 * 18 * 6 points
        self.cylinder = make_cylinder()  # same as 1944
        self.sphere = make_sphere()  # 1944 points
        # The three point sets are placed side by side, one row per point.
        self.grid = torch.Tensor(
            np.hstack((self.box, self.cylinder, self.sphere)))

        #     1st folding
        self.Fold1 = FoldingNetSingle(Folding1_dims)
        #     2nd folding
        self.Fold2 = FoldingNetSingle(Folding2_dims)
        # Number of points needed to replicate the codeword later. It must
        # equal the number of grid rows (the grid is concatenated with the
        # replicated codeword), so read it from the grid instead of repeating
        # the literal 1944.
        self.N = self.grid.shape[0]
        self.fc = nn.Linear(9, 9, True)  # geometric transformation
コード例 #3
0
ファイル: foldingnet.py プロジェクト: mahaling/FoldingNet
class FoldingNetShapes(nn.Module):
    """FoldingNet variant that folds a fixed Nx9 shape template.

    The template stacks the box, cylinder and sphere point sets side by side
    and is passed through a learnable 9 -> 9 linear layer before folding.
    """

    def __init__(self,
                 MLP_dims,
                 FC_dims,
                 Folding1_dims,
                 Folding2_dims,
                 MLP_doLastRelu=False):
        """Set up the PointNet encoder, the shape template and both folding stages.

        Args:
            MLP_dims: dimension spec forwarded to ``PointNetVanilla``; its last
                entry must equal ``FC_dims[0]``.
            FC_dims: dimension spec forwarded to ``PointNetVanilla``.
            Folding1_dims: dimension spec forwarded to the first
                ``FoldingNetSingle``.
            Folding2_dims: dimension spec forwarded to the second
                ``FoldingNetSingle``.
            MLP_doLastRelu: flag forwarded unchanged to ``PointNetVanilla``.
        """
        assert (MLP_dims[-1] == FC_dims[0])
        super(FoldingNetShapes, self).__init__()
        # Encoder
        #   PointNet
        self.PointNet = PointNetVanilla(MLP_dims, FC_dims, MLP_doLastRelu)

        # Decoder
        #   Folding
        self.box = make_box()  # 18 * 18 * 6 points
        self.cylinder = make_cylinder()  # same as 1944
        self.sphere = make_sphere()  # 1944 points
        self.grid = torch.Tensor(
            np.hstack((self.box, self.cylinder, self.sphere)))

        #     1st folding
        self.Fold1 = FoldingNetSingle(Folding1_dims)
        #     2nd folding
        self.Fold2 = FoldingNetSingle(Folding2_dims)
        # Number of points needed to replicate the codeword later. It must
        # match the number of grid rows, so derive it from the grid instead of
        # repeating the literal 1944.
        self.N = self.grid.shape[0]
        self.fc = nn.Linear(9, 9, True)  # geometric transformation

    def forward(self, X):
        """Encode ``X`` into a codeword and fold the template twice.

        Args:
            X: input batch handed to the PointNet encoder.

        Returns:
            The output of the second folding stage (BxNx3 according to the
            dimension comments below).
        """
        # encoding (call the modules directly so registered hooks run)
        f = self.PointNet(X)  # BxK
        codeword = f.unsqueeze(1).expand(-1, self.N, -1)  # BxNxK
        B = codeword.shape[0]  # extract batch size

        # Put the template on the codeword's device and dtype. A bare .cuda()
        # targets the current CUDA device, which may not be the input's, and
        # never matches a non-float32 model.
        grid = self.grid.to(device=codeword.device, dtype=codeword.dtype)
        # Transform once, then broadcast: every batch entry shares the same
        # template, so there is no need to run the layer on B copies.
        grid = self.fc(grid)  # Nx9
        grid = grid.unsqueeze(0).expand(B, -1, -1)  # BxNx9

        # 1st folding
        f = torch.cat((grid, codeword), 2)  # BxNx(K+9)
        f = self.Fold1(f)  # BxNx3

        # 2nd folding
        f = torch.cat((f, codeword), 2)  # BxNx(K+3)
        f = self.Fold2(f)  # BxNx3
        return f
コード例 #4
0
ファイル: foldingnet.py プロジェクト: mahaling/FoldingNet
class FoldingNetVanilla(nn.Module):  # PointNetVanilla or nn.Sequential
    """PointNet encoder followed by two folding stages over a fixed 2D grid."""

    def __init__(self,
                 MLP_dims,
                 FC_dims,
                 grid_dims,
                 Folding1_dims,
                 Folding2_dims,
                 MLP_doLastRelu=False):
        """Set up the PointNet encoder, the fixed 2D grid and both folding stages.

        Args:
            MLP_dims: dimension spec forwarded to ``PointNetVanilla``; its last
                entry must equal ``FC_dims[0]``.
            FC_dims: dimension spec forwarded to ``PointNetVanilla``.
            grid_dims: indexable pair; the grid holds
                ``grid_dims[0] * grid_dims[1]`` points.
            Folding1_dims: dimension spec forwarded to the first
                ``FoldingNetSingle``.
            Folding2_dims: dimension spec forwarded to the second
                ``FoldingNetSingle``.
            MLP_doLastRelu: flag forwarded unchanged to ``PointNetVanilla``.
        """
        assert (MLP_dims[-1] == FC_dims[0])
        super(FoldingNetVanilla, self).__init__()
        # Encoder
        #   PointNet
        self.PointNet = PointNetVanilla(MLP_dims, FC_dims, MLP_doLastRelu)

        # Decoder
        #   Folding
        #     2D grid: (grid_dims[0] * grid_dims[1]) x 2
        # TODO: normalize the grid to align with the input data
        self.N = grid_dims[0] * grid_dims[1]
        self.grid = self._make_grid(grid_dims)  # Nx2

        #     1st folding
        self.Fold1 = FoldingNetSingle(Folding1_dims)
        #     2nd folding
        self.Fold2 = FoldingNetSingle(Folding2_dims)

    @staticmethod
    def _make_grid(grid_dims):
        """Return the (grid_dims[0] * grid_dims[1]) x 2 lattice of (u, v) points.

        Coordinates are ``i / n - 0.5`` along each axis. The axes are created
        as float32 explicitly because dividing an integer ``arange`` by an int
        is floor division on older PyTorch releases, which would collapse every
        coordinate to -0.5.
        """
        n_u, n_v = grid_dims[0], grid_dims[1]
        u_axis = torch.arange(0, n_u, dtype=torch.float32) / n_u - 0.5
        v_axis = torch.arange(0, n_v, dtype=torch.float32) / n_v - 0.5
        # u cycles fastest; each v value is held for n_u consecutive rows.
        u = u_axis.repeat(n_v)
        v = v_axis.expand(n_u, -1).t().reshape(-1)
        return torch.stack((u, v), 1)

    def forward(self, X):
        """Encode ``X`` into a codeword and fold the 2D grid twice.

        Args:
            X: input batch handed to the PointNet encoder.

        Returns:
            The output of the second folding stage (BxNx3 according to the
            dimension comments below).
        """
        # encoding (call the modules directly so registered hooks run)
        f = self.PointNet(X)  # BxK
        codeword = f.unsqueeze(1).expand(-1, self.N, -1)  # BxNxK
        B = codeword.shape[0]  # extract batch size

        # Put the grid on the codeword's device and dtype. A bare .cuda()
        # targets the current CUDA device, which may not be the input's.
        grid = self.grid.to(device=codeword.device, dtype=codeword.dtype)
        grid = grid.unsqueeze(0).expand(B, -1, -1)  # BxNx2

        # 1st folding
        f = torch.cat((grid, codeword), 2)  # BxNx(K+2)
        f = self.Fold1(f)  # BxNx3

        # 2nd folding
        f = torch.cat((f, codeword), 2)  # BxNx(K+3)
        f = self.Fold2(f)  # BxNx3
        return f