Ejemplo n.º 1
0
def transform_setup(graph_u=False,
                    graph_gcn=False,
                    rotation=180,
                    samplePoints=1024,
                    mesh=False,
                    node_translation=0.01):
    """Build the ``(transform, pretransform)`` pair for a model variant.

    ``pretransform`` is always ``Compose([NormalizeScale(), Center()])``.
    ``transform`` depends on the selected variant:

    * neither ``graph_u`` nor ``graph_gcn`` set: sample ``samplePoints``
      points, then apply a random rotation;
    * ``graph_u`` set: normalize scale, center, sample ``samplePoints`` points
      with ``SamplePoints(samplePoints, True, True)``, randomly rotate, then
      build a k-NN graph with ``k=graph_u``;
    * ``graph_gcn`` set and ``mesh`` truthy: randomly rotate, generate mesh
      normals, convert faces to edges and attach normalized ``Distance`` and
      ``TargetIndegree`` edge attributes;
    * ``graph_gcn`` set and ``mesh`` falsy: sample points, build a k-NN graph
      with ``k=graph_gcn`` and attach normalized ``Distance`` edge attributes.

    Args:
        graph_u: k of the k-NN graph for the ``graph_u`` variant; falsy to
            disable. Takes precedence over ``graph_gcn``.
        graph_gcn: k of the k-NN graph for the ``graph_gcn`` variant; falsy to
            disable.
        rotation: ``(degrees, axis)`` forwarded to ``T.RandomRotate``, or a
            single number of degrees, which is treated as ``(rotation, 0)``.
        samplePoints: number of points drawn by ``T.SamplePoints``.
        mesh: in the ``graph_gcn`` variant, truthy to build the graph from the
            mesh faces instead of sampled points.
        node_translation: unused; kept for backward compatibility.

    Returns:
        ``(transform, pretransform)``.
    """
    import numbers

    # A bare number (including the default 180) used to crash on
    # ``rotation[0]``; treat it as degrees about axis 0, RandomRotate's
    # default axis.
    if isinstance(rotation, numbers.Number):
        degrees, axis = rotation, 0
    else:
        degrees, axis = rotation[0], rotation[1]

    # Scale normalization and centering, shared by every variant.
    pretransform = T.Compose([T.NormalizeScale(), T.Center()])

    if not graph_u and not graph_gcn:
        # Default: point sampling and random rotation.
        transform = T.Compose([
            T.SamplePoints(samplePoints),
            T.RandomRotate(degrees, axis)
        ])
        print("pointnet rotation {}".format(rotation))
    elif graph_u:
        transform = T.Compose([
            T.NormalizeScale(),
            T.Center(),
            T.SamplePoints(samplePoints, True, True),
            T.RandomRotate(degrees, axis),
            T.KNNGraph(k=graph_u)
        ])
    else:
        # graph_gcn is truthy here: the first two branches cover every other
        # combination, so no "no transform" fallback is reachable.
        if mesh:
            # Every truthy ``mesh`` value (including "extraFeatures") used
            # the same pipeline.
            transform = T.Compose([
                T.RandomRotate(degrees, axis),
                T.GenerateMeshNormals(),
                T.FaceToEdge(True),
                T.Distance(norm=True),
                T.TargetIndegree(cat=True)
            ])
        else:
            # NOTE(review): no RandomRotate in this pipeline even though the
            # rotation is printed below — confirm this is intended.
            transform = T.Compose([
                T.SamplePoints(samplePoints, True, True),
                T.KNNGraph(k=graph_gcn),
                T.Distance(norm=True)
            ])
            print("no mesh")
        print("Rotation {}".format(rotation))
        print("Meshing {}".format(mesh))

    return transform, pretransform
Ejemplo n.º 2
0
    def __init__(self,
                 root: str,
                 device: torch.device = torch.device("cpu"),
                 train: bool = True,
                 test: bool = True,
                 transform_data: bool = True):
        """Load the dataset (categories big_cats, cows, dogs, hippos, horses).

        Args:
            root: dataset root directory, forwarded to the base class.
            device: device the ``ToDevice`` transform moves each sample to.
            train: keep the samples whose ``pose`` attribute is < 16.
            test: keep the samples whose ``pose`` attribute is >= 16. When
                ``train`` and ``test`` are both true, or both false, every
                sample is kept.
            transform_data: if true, each accessed sample is augmented with
                ``Move(mean=[0,0,0], std=[0.05,0.05,0.05])`` and
                ``Rotate(dims=[0,1,2])`` before being moved to ``device``;
                otherwise it is only moved to ``device``.
        """
        self.url = 'https://drive.google.com/file/d/1dp4sMvZ8cmIIITE-qj6zYpZb0-v-4Kgf/view?usp=sharing'
        self.categories = ["big_cats","cows","dogs","hippos","horses"]

        # Center each mesh on its centroid.
        pre_transform = transforms.Center()

        if transform_data:
            # Rotate and move, then transfer to the target device.
            transform = transforms.Compose([
                Move(mean=[0,0,0], std=[0.05,0.05,0.05]), 
                Rotate(dims=[0,1,2]), 
                ToDevice(device)])
        else:
            transform = ToDevice(device)

        super().__init__(root=root, transform=transform, pre_transform=pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
        # Built from the first sample of the full dataset, before any split.
        self.downscaler = dscale.Downscaler(
            filename=join(self.processed_dir,"ds"), mesh=self.get(0), factor=2)

        # Split by pose: keep exactly one side only when exactly one of
        # train/test is requested. Each sample is fetched once (the old
        # filter called self.get(i) twice per index).
        if bool(train) != bool(test):
            samples = (self.get(i) for i in range(len(self)))
            if train:
                subset = [s for s in samples if s.pose < 16]
            else:
                subset = [s for s in samples if s.pose >= 16]
            self.data, self.slices = self.collate(subset)
Ejemplo n.º 3
0
    def __init__(self,
                 root: str,
                 device: torch.device = torch.device("cpu"),
                 train: bool = True,
                 test: bool = True,
                 transform_data: bool = True):
        """Load the dataset and optionally keep only the train or test split.

        Args:
            root: dataset root directory, forwarded to the base class.
            device: device the ``ToDevice`` transform moves each sample to.
            train: keep the training samples (indices 40-199 and 240-399).
            test: keep the test samples (indices 0-39 and 200-239). When
                ``train`` and ``test`` are both true, or both false, every
                sample is kept.
            transform_data: if true, each accessed sample is centered and
                augmented with ``Rotate(dims=[1])``, ``Move`` and
                ``RandomTranslate(0.01)`` before being moved to ``device``;
                otherwise it is only moved to ``device``.
        """
        self.url = 'http://www.cs.cf.ac.uk/shaperetrieval/shrec14/'

        if transform_data:
            # Rotate and move, then transfer to the target device.
            transform = transforms.Compose([
                transforms.Center(),
                Rotate(dims=[1]),
                Move(mean=[0, 0, 0], std=[0.05, 0.05, 0.05]),
                transforms.RandomTranslate(0.01),
                ToDevice(device)
            ])
        else:
            transform = ToDevice(device)

        # pre_transform centers each mesh on its centroid.
        super().__init__(root=root,
                         transform=transform,
                         pre_transform=transforms.Center())

        self.data, self.slices = torch.load(self.processed_paths[0])

        test_indices = list(range(0, 40)) + list(range(200, 240))
        train_indices = list(range(40, 200)) + list(range(240, 400))
        if train and not test:
            selected = train_indices
        elif not train and test:
            selected = test_indices
        else:
            selected = None

        if selected is not None:
            # Use self.get(): self[i] would also run self.transform, freezing
            # one random augmentation (and the device move) into the stored
            # split and then applying the augmentation again on every access.
            self.data, self.slices = self.collate(
                [self.get(i) for i in selected])
Ejemplo n.º 4
0
def test_compose():
    """Compose prints its children and applies Center, then AddSelfLoops."""
    pipeline = T.Compose([T.Center(), T.AddSelfLoops()])
    expected_repr = ('Compose([\n'
                     '  Center(),\n'
                     '  AddSelfLoops()\n'
                     '])')
    assert str(pipeline) == expected_repr

    # Three collinear points joined as an undirected path 0-1-2.
    graph = Data(edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
                 pos=torch.Tensor([[0, 0], [2, 0], [4, 0]]))
    graph = pipeline(graph)

    assert len(graph) == 2                      # edge_index and pos
    assert graph.pos.tolist() == [[-2, 0], [0, 0], [2, 0]]
    assert graph.edge_index.size() == (2, 7)    # 4 edges + 3 self-loops
Ejemplo n.º 5
0
def test_compose():
    """Compose's repr lists its transforms; calling it centers and adds self-loops."""
    center_and_loop = T.Compose([T.Center(), T.AddSelfLoops()])
    assert repr(center_and_loop) == ('Compose([\n'
                                     '    Center(),\n'
                                     '    AddSelfLoops(),\n'
                                     '])')

    points = torch.tensor([[0, 0], [2, 0], [4, 0]], dtype=torch.float)
    path_edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    result = center_and_loop(Data(edge_index=path_edges, pos=points))

    expected_edges = [[0, 0, 1, 1, 1, 2, 2],
                      [0, 1, 0, 1, 2, 1, 2]]
    assert result.pos.tolist() == [[-2, 0], [0, 0], [2, 0]]
    assert result.edge_index.tolist() == expected_edges
Ejemplo n.º 6
0
import torch_geometric
#from torch_geometric.transforms import GridSampling
import torch_geometric.transforms as T
from human_corres.data import ImgData, ImgBatch
from human_corres.config import PATH_TO_SURREAL
import human_corres.transforms as H

num_views = 20
# Two ID ranges of equal length stacked row-wise, giving shape
# (2, 100000 * num_views): row 0 is [0, 100000 * num_views), row 1 is
# [115000 * num_views, 215000 * num_views).
IDlist = np.vstack((np.arange(100000 * num_views),
                    np.arange(115000 * num_views, 215000 * num_views)))
num_test = 5000
# Center, then apply RandomRotate(30) about axes 0, 1 and 2 in turn.
DefaultTransform = T.Compose(
    [T.Center()] + [T.RandomRotate(30, axis=ax) for ax in range(3)])


class SurrealFEPts5k(Dataset):
    """Surreal 3D points for Feature Extraction (FE).

    Samples a fixed number of points.

    Output: dictionary with keys {points3d, correspondence}
    Data Format:
      points3d: [num_points, 3] real numbers.
      correspondence: [num_points] integers in range [6890].
    """
Ejemplo n.º 7
0
 def __init__(self, center):
     """Remember the centering flag and build a ``T.Center`` transform if it is truthy.

     ``self.center`` is ``None`` when centering is disabled.
     """
     self.is_center = center
     self.center = T.Center() if center else None
import scipy.io as sio
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
from human_corres.utils import helper
from human_corres.data import Data
import torch_geometric
import torch_geometric.transforms as T
import human_corres as hc
import human_corres.transforms as H

num_views = 20
IDlist = np.arange(5000 * num_views)
num_test = 50
# Training: center, RandomRotate(30) about axes 0, 1 and 2 in turn, then
# grid-subsample with cell size 0.01.
TrainTransform = T.Compose([
    T.Center(),
    *(T.RandomRotate(30, axis=axis) for axis in (0, 1, 2)),
    H.GridSampling(0.01),
])
# Testing: same centering and subsampling, no random rotation.
TestTransform = T.Compose([T.Center(), H.GridSampling(0.01)])


class DGFSurrealFEPts(Dataset):
    """Surreal 3D points for Feature Extraction (FE).

    Output: dictionary with keys {points3d, correspondence}
    Data Format:
      points3d: [num_points, 3] real numbers.
      correspondence: [num_points] integers in range [6890].
    """


def _is_evaluated_batch(start):
    """Return True if the batch whose first sample index is *start* is evaluated.

    Equivalent to the original three-way test: a batch is processed only when
    ``start``, ``start + 1`` or ``start + 2`` is a multiple of 100, i.e. when
    ``start % 100`` is 0, 98 or 99.
    """
    return start % 100 in (0, 98, 99)


def _evaluate_batch(model, handler, data_list, device, start, corres_fmt,
                    reg_fmt):
    """Run the model on one batch and save two predictions per sample.

    Each sample is centered, grid-subsampled with ``H.GridSampling(0.01)`` and
    passed to ``model``. Two predictions are interpolated back onto the
    full-resolution centered points with ``knn_interpolate(..., 3)``,
    evaluated against the last three columns of ``y`` and saved:

    * the template points matched by feature correspondence, to
      ``corres_fmt.format(index)``;
    * the model's ``'x_out0'`` output, to ``reg_fmt.format(index)``.
    """
    centered = [T.Center()(data).to(device) for data in data_list]
    # Independent float32 copies of the full-resolution points and targets,
    # taken before grid sampling. This replaces the former
    # device -> CPU -> numpy -> device round trip.
    saved_pos = [
        data.pos.detach().to(device=device, dtype=torch.float, copy=True)
        for data in centered
    ]
    saved_y = [
        data.y.detach().to(device=device, dtype=torch.float, copy=True)
        for data in centered
    ]
    sampled = [H.GridSampling(0.01)(data).to(device) for data in centered]

    out_dict = model(sampled)
    corres = correspondence(out_dict['feats'], handler.template_feats)
    x_before_reg = handler.template_points[corres, :]
    x_after_reg = out_dict['x_out0']

    # Outputs appear to be concatenated over the batch's samples; walk them
    # with a running offset.
    offset = 0
    for idx, data in enumerate(sampled):
        end = offset + data.pos.shape[0]
        pred0 = Prediction(data.pos, x_before_reg[offset:end])
        pred1 = Prediction(data.pos, x_after_reg[offset:end])
        pred0 = pred0.knn_interpolate(saved_pos[idx], 3)
        pred1 = pred1.knn_interpolate(saved_pos[idx], 3)
        offset = end
        pred0.evaluate_errors(saved_y[idx][:, -3:])
        pred1.evaluate_errors(saved_y[idx][:, -3:])
        pred0.save_to_mat(corres_fmt.format(start + idx))
        pred1.save_to_mat(reg_fmt.format(start + idx))


def test(model, loader, args, epoch, REG):
    """Evaluate correspondence/registration on a subsample of *loader* and save results.

    Only batches selected by ``_is_evaluated_batch`` are processed. For each
    sample in those batches, ``_evaluate_batch`` saves the correspondence
    prediction to ``CORRES.format(i)`` and the registration prediction to
    ``REG.format(i)``, where ``i`` is the sample's running index.

    Args:
        model: network returning a dict with ``'feats'`` and ``'x_out0'``.
        loader: iterable of lists of samples; ``len(loader)`` sizes the
            progress bar.
        args: namespace providing ``device``.
        epoch: unused; kept for interface compatibility.
        REG: format string with one positional field for the sample index.
            The correspondence path is derived by replacing ``'REG_'`` with
            ``'CORRES_'``.

    Returns:
        An empty dict.
    """
    # NOTE(review): the model is put in *training* mode during evaluation.
    # If it has dropout or batch-norm layers this changes the predictions.
    # Kept as-is; confirm whether model.eval() was intended.
    model.train()
    handler = MultiGPUHandler(loader, args, training=False, reg=True)
    CORRES = REG.replace('REG_', 'CORRES_')

    start = 0  # running index of the first sample in the current batch
    with progressbar.ProgressBar(max_value=len(loader),
                                 widgets=handler.widgets) as bar:
        for i, data_list in enumerate(loader):
            if _is_evaluated_batch(start):
                _evaluate_batch(model, handler, data_list, args.device,
                                start, CORRES, REG)
            start += len(data_list)
            # Previously the bar was never advanced; it only jumped to 100%
            # when the context manager exited.
            bar.update(i + 1)

    torch.cuda.empty_cache()
    return {}