def __init__(self, MLP_dims, FC_dims, grid_dims, Folding1_dims, Folding2_dims,
             MLP_doLastRelu=False):
    """Initialise a ``FoldingNetVanilla``: PointNet encoder, 2-D grid, two foldings.

    NOTE(review): this is a module-level function named ``__init__`` that
    duplicates the constructor of class ``FoldingNetVanilla`` in this file; it
    is not attached to any class here, and a later module-level
    ``def __init__`` in this file rebinds the same name. Confirm whether it is
    still needed.

    Args:
        MLP_dims: layer sizes forwarded to ``PointNetVanilla``; the last entry
            must equal ``FC_dims[0]``.
        FC_dims: fully-connected layer sizes forwarded to ``PointNetVanilla``.
        grid_dims: pair ``(n_u, n_v)``; the decoder grid has ``n_u * n_v`` points.
        Folding1_dims: sizes forwarded to the first ``FoldingNetSingle``.
        Folding2_dims: sizes forwarded to the second ``FoldingNetSingle``.
        MLP_doLastRelu: forwarded to ``PointNetVanilla``.
    """
    assert MLP_dims[-1] == FC_dims[0], "MLP_dims[-1] must equal FC_dims[0]"
    super(FoldingNetVanilla, self).__init__()
    # Encoder
    self.PointNet = PointNetVanilla(MLP_dims, FC_dims, MLP_doLastRelu)
    # Decoder: 2-D grid of (grid_dims[0] * grid_dims[1]) x 2 coordinates.
    # TODO: normalize the grid to align with the input data
    self.N = grid_dims[0] * grid_dims[1]
    # Explicit float dtype: older PyTorch releases floor-divide integer tensors,
    # which would corrupt the grid coordinates computed below.
    dtype = torch.get_default_dtype()
    u = (torch.arange(0, grid_dims[0], dtype=dtype) / grid_dims[0] - 0.5).repeat(
        grid_dims[1])
    v = (torch.arange(0, grid_dims[1], dtype=dtype) / grid_dims[1] - 0.5).expand(
        grid_dims[0], -1).t().reshape(-1)
    self.grid = torch.stack((u, v), 1)  # Nx2
    # 1st folding
    self.Fold1 = FoldingNetSingle(Folding1_dims)
    # 2nd folding
    self.Fold2 = FoldingNetSingle(Folding2_dims)
def __init__(self, MLP_dims, FC_dims, Folding1_dims, Folding2_dims,
             MLP_doLastRelu=False):
    """Initialise a ``FoldingNetShapes``: PointNet encoder, shape grid, two foldings.

    NOTE(review): this is a module-level function named ``__init__`` that
    duplicates the constructor of class ``FoldingNetShapes`` in this file; it
    is not attached to any class here and rebinds any earlier module-level
    ``__init__``. Confirm whether it is still needed.

    Args:
        MLP_dims: layer sizes forwarded to ``PointNetVanilla``; the last entry
            must equal ``FC_dims[0]``.
        FC_dims: fully-connected layer sizes forwarded to ``PointNetVanilla``.
        Folding1_dims: sizes forwarded to the first ``FoldingNetSingle``.
        Folding2_dims: sizes forwarded to the second ``FoldingNetSingle``.
        MLP_doLastRelu: forwarded to ``PointNetVanilla``.
    """
    assert MLP_dims[-1] == FC_dims[0], "MLP_dims[-1] must equal FC_dims[0]"
    super(FoldingNetShapes, self).__init__()
    # Encoder
    self.PointNet = PointNetVanilla(MLP_dims, FC_dims, MLP_doLastRelu)
    # Decoder: grid built from three shape point sets
    self.box = make_box()  # 18 * 18 * 6 points
    self.cylinder = make_cylinder()  # same as 1944
    self.sphere = make_sphere()  # 1944 points
    # Column-wise concatenation of the three point sets; self.fc expects 9 columns.
    self.grid = torch.tensor(
        np.hstack((self.box, self.cylinder, self.sphere)),
        dtype=torch.get_default_dtype())
    # 1st folding
    self.Fold1 = FoldingNetSingle(Folding1_dims)
    # 2nd folding
    self.Fold2 = FoldingNetSingle(Folding2_dims)
    # Number of grid points (rows) = times the codeword is replicated in forward.
    # Derived from the grid instead of hard-coded so the two can never disagree.
    self.N = self.grid.shape[0]
    self.fc = nn.Linear(9, 9, True)  # geometric transformation
class FoldingNetShapes(nn.Module):
    """FoldingNet variant whose folding grid comes from three shape point sets.

    The grid is the column-wise concatenation of ``make_box()``,
    ``make_cylinder()`` and ``make_sphere()``. It is passed through a learnable
    ``nn.Linear(9, 9)`` and concatenated with the replicated codeword. The
    result is then folded twice, as in the vanilla FoldingNet.
    """

    def __init__(self, MLP_dims, FC_dims, Folding1_dims, Folding2_dims,
                 MLP_doLastRelu=False):
        """Build the encoder, the shape grid, both folding stages and the grid transform.

        Args:
            MLP_dims: layer sizes forwarded to ``PointNetVanilla``; the last
                entry must equal ``FC_dims[0]``.
            FC_dims: fully-connected layer sizes forwarded to ``PointNetVanilla``.
            Folding1_dims: sizes forwarded to the first ``FoldingNetSingle``.
            Folding2_dims: sizes forwarded to the second ``FoldingNetSingle``.
            MLP_doLastRelu: forwarded to ``PointNetVanilla``.
        """
        assert MLP_dims[-1] == FC_dims[0], "MLP_dims[-1] must equal FC_dims[0]"
        super(FoldingNetShapes, self).__init__()
        # Encoder
        self.PointNet = PointNetVanilla(MLP_dims, FC_dims, MLP_doLastRelu)
        # Decoder: grid built from three shape point sets
        self.box = make_box()  # 18 * 18 * 6 points
        self.cylinder = make_cylinder()  # same as 1944
        self.sphere = make_sphere()  # 1944 points
        # Column-wise concatenation; self.fc below expects 9 columns.
        self.grid = torch.tensor(
            np.hstack((self.box, self.cylinder, self.sphere)),
            dtype=torch.get_default_dtype())
        # 1st folding
        self.Fold1 = FoldingNetSingle(Folding1_dims)
        # 2nd folding
        self.Fold2 = FoldingNetSingle(Folding2_dims)
        # Number of grid points (rows) = times the codeword is replicated in
        # forward. Derived from the grid so the two can never disagree.
        self.N = self.grid.shape[0]
        self.fc = nn.Linear(9, 9, True)  # geometric transformation

    def forward(self, X):
        """Encode ``X``, transform the shape grid, and fold it twice.

        Args:
            X: input batch for ``self.PointNet``, whose output is used as a
                BxK codeword.

        Returns:
            The output of ``self.Fold2`` (BxNx3 per the original shape comments).
        """
        # Encoding: BxK -> BxNxK (expand is a view, not a copy).
        codeword = self.PointNet(X).unsqueeze(1).expand(-1, self.N, -1)
        B = codeword.shape[0]  # batch size
        # Match the grid to the input's device and to fc's dtype. `.to` is a
        # no-op when both already match, unlike the previous `.cuda()`, which
        # always targeted the current default GPU.
        grid = self.grid.to(device=codeword.device, dtype=self.fc.weight.dtype)  # Nx9
        # fc is applied row-wise, so transform the Nx9 grid once and then
        # broadcast it. This avoids recomputing it for B identical copies.
        grid = self.fc(grid).unsqueeze(0).expand(B, -1, -1)  # BxNx9
        # 1st folding
        f = self.Fold1(torch.cat((grid, codeword), 2))  # BxNx(K+9) -> BxNx3
        # 2nd folding
        f = self.Fold2(torch.cat((f, codeword), 2))  # BxNx(K+3) -> BxNx3
        return f
class FoldingNetVanilla(nn.Module):  # PointNetVanilla or nn.Sequential
    """FoldingNet autoencoder: PointNet encoder plus two folding stages over a fixed 2-D grid.

    The codeword produced by ``self.PointNet`` is replicated once per grid
    point, concatenated with the grid coordinates and folded by ``self.Fold1``.
    That result is concatenated with the codeword again and folded by
    ``self.Fold2``.
    """

    def __init__(self, MLP_dims, FC_dims, grid_dims, Folding1_dims,
                 Folding2_dims, MLP_doLastRelu=False):
        """Build the encoder, the 2-D folding grid and both folding stages.

        Args:
            MLP_dims: layer sizes forwarded to ``PointNetVanilla``; the last
                entry must equal ``FC_dims[0]``.
            FC_dims: fully-connected layer sizes forwarded to ``PointNetVanilla``.
            grid_dims: pair ``(n_u, n_v)``; the grid has ``n_u * n_v`` points.
            Folding1_dims: sizes forwarded to the first ``FoldingNetSingle``.
            Folding2_dims: sizes forwarded to the second ``FoldingNetSingle``.
            MLP_doLastRelu: forwarded to ``PointNetVanilla``.
        """
        assert MLP_dims[-1] == FC_dims[0], "MLP_dims[-1] must equal FC_dims[0]"
        super(FoldingNetVanilla, self).__init__()
        # Encoder
        self.PointNet = PointNetVanilla(MLP_dims, FC_dims, MLP_doLastRelu)
        # Decoder: 2-D grid of (grid_dims[0] * grid_dims[1]) x 2 coordinates.
        # TODO: normalize the grid to align with the input data
        self.N = grid_dims[0] * grid_dims[1]
        self.grid = self._make_grid(grid_dims)  # Nx2
        # 1st folding
        self.Fold1 = FoldingNetSingle(Folding1_dims)
        # 2nd folding
        self.Fold2 = FoldingNetSingle(Folding2_dims)

    @staticmethod
    def _make_grid(grid_dims):
        """Return a regular (grid_dims[0] * grid_dims[1]) x 2 grid of 2-D coordinates.

        Along an axis of size ``n`` the coordinates are ``i / n - 0.5`` for
        ``i in range(n)``. The first column (u) varies fastest.
        """
        n_u, n_v = grid_dims[0], grid_dims[1]
        # Explicit float dtype: older PyTorch releases floor-divide integer
        # tensors, which would corrupt the coordinates.
        dtype = torch.get_default_dtype()
        u = (torch.arange(0, n_u, dtype=dtype) / n_u - 0.5).repeat(n_v)
        v = (torch.arange(0, n_v, dtype=dtype) / n_v - 0.5).expand(
            n_u, -1).t().reshape(-1)
        return torch.stack((u, v), 1)

    def forward(self, X):
        """Encode ``X`` and fold the 2-D grid twice, conditioned on the codeword.

        Args:
            X: input batch for ``self.PointNet``, whose output is used as a
                BxK codeword.

        Returns:
            The output of ``self.Fold2`` (BxNx3 per the original shape comments).
        """
        # Encoding: BxK -> BxNxK (expand is a view, not a copy).
        codeword = self.PointNet(X).unsqueeze(1).expand(-1, self.N, -1)
        B = codeword.shape[0]  # batch size
        # Match the grid to the codeword's device and dtype. `.to` is a no-op
        # when both already match, unlike the previous `.cuda()`, which always
        # targeted the current default GPU.
        grid = self.grid.to(device=codeword.device, dtype=codeword.dtype)  # Nx2
        grid = grid.unsqueeze(0).expand(B, -1, -1)  # BxNx2
        # 1st folding
        f = self.Fold1(torch.cat((grid, codeword), 2))  # BxNx(K+2) -> BxNx3
        # 2nd folding
        f = self.Fold2(torch.cat((f, codeword), 2))  # BxNx(K+3) -> BxNx3
        return f