# Example #1
# 0
    def __call__(self, data):
        """Attach geometric edge features computed from the mesh to ``data``.

        ``data`` is converted with ``to_trimesh`` and passed to
        ``getGeometricEdgeFeatures`` along with this transform's
        ``pp_features``, ``triangle_features`` and ``n_components``
        settings. The returned edge features are stored on ``data.edge_attr``
        as ``float32``. When ``self.assign_edges`` is true, the returned edge
        index is transposed and stored on ``data.edge_index`` as ``int64``.

        Args:
            data: Graph object with a ``pos`` whose last dimension is 3 and a
                ``face`` whose first dimension is 3.

        Returns:
            The same ``data`` object, modified in place.

        Raises:
            AssertionError: If ``pos``/``face`` is missing or misshaped.
        """
        assert data.face is not None
        assert data.pos is not None
        assert data.pos.size(-1) == 3
        assert data.face.size(0) == 3

        mesh = to_trimesh(data)
        edge_index, edge_attr = getGeometricEdgeFeatures(
            mesh,
            pp_features=self.pp_features,
            triangle_features=self.triangle_features,
            n_components=self.n_components)

        # Create the new tensors on the same device as the existing data.
        # A mesh may carry neither node features nor an edge index (only
        # ``pos`` and ``face``), so fall back to ``pos``, which the asserts
        # above guarantee is present. Previously a missing ``edge_index``
        # crashed here with an AttributeError on ``None``.
        if data.x is not None:
            device = data.x.device
        elif data.edge_index is not None:
            device = data.edge_index.device
        else:
            device = data.pos.device

        data.edge_attr = torch.tensor(edge_attr, dtype=torch.float32,
                                      device=device)

        if self.assign_edges:
            # NOTE(review): presumably getGeometricEdgeFeatures returns edges
            # as (num_edges, 2); the transpose gives (2, num_edges) -- confirm.
            data.edge_index = torch.tensor(edge_index.T, dtype=torch.int64,
                                           device=device)

        return data
# Example #2
# 0
def test_trimesh():
    """Check that a two-triangle square survives a trimesh round trip exactly.

    The mesh is converted with ``to_trimesh`` and back with ``from_trimesh``.
    Both vertex positions and face indices must come back unchanged and in
    the same order.

    NOTE(review): ``test_trimesh`` is defined again further down in this
    file, so this version is shadowed and will not be collected by pytest.
    """
    vertices = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]
    triangles = [[0, 1, 2], [1, 2, 3]]
    pos = torch.tensor(vertices, dtype=torch.float)
    face = torch.tensor(triangles).t()

    restored = from_trimesh(to_trimesh(Data(pos=pos, face=face)))

    assert pos.tolist() == restored.pos.tolist()
    assert face.tolist() == restored.face.tolist()
def test_trimesh():
    """Round-trip a two-triangle square through trimesh, tolerating reordering.

    The round trip may reorder vertices. The recovered vertices are
    therefore sorted back into the original order before comparison, and
    the recovered face indices are remapped into the original vertex
    numbering.
    """
    pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],
                       dtype=torch.float)
    face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t()

    data = Data(pos=pos, face=face)
    mesh = to_trimesh(data)
    data = from_trimesh(mesh)

    # The sort key x + 2y + 3z is distinct for every vertex of ``pos`` and
    # increases in the original order. ``perm[i]`` is therefore the index in
    # the recovered mesh of original vertex ``i``.
    perm = (data.pos * torch.tensor([[1.0, 2.0, 3.0]])).sum(dim=-1).argsort()
    assert pos.tolist() == data.pos[perm].tolist()

    # ``data.face`` holds recovered indices. Map them back to original
    # indices through the inverse permutation. Indexing with ``perm`` itself
    # is only correct when the permutation happens to be its own inverse.
    inv_perm = perm.argsort()
    assert face.tolist() == inv_perm[data.face].tolist()
# Example #4
# 0
    train_dataloader = FAUSTDataLoader(train_dataset, batch_size=batch_size)
    val_dataloader = FAUSTDataLoader(val_dataset, batch_size=batch_size)
elif args.datasets == 'full_faust':
    dataset = FullFAUST('.')
    # TODO: Remove hardcoding of test person
    train_dataset, val_dataset = split_faust_by_person(dataset, [1])
    train_dataloader = FAUSTDataLoader(train_dataset, batch_size=batch_size)
    val_dataloader = FAUSTDataLoader(val_dataset, batch_size=batch_size)
else:
    raise NotImplementedError('')

# NOTE(review): the comment below does not appear to match the code that
# follows (template loading and hyperparameters); confirm whether it is stale.
# for DFAUST splitting. 9:1 ratio, with seed from above

# Use the first training sample as the template mesh and wrap its trimesh
# conversion with pv.wrap (presumably PyVista) for the writer.
print('Template')
template = train_dataset[0]
pv_template = pv.wrap(to_trimesh(template))
print(template)
print()

print('Initialise writer')
# NOTE(review): this rebinds ``writer`` from the imported module to a
# MeshWriter instance; after this line the module is no longer reachable
# under that name.
writer = writer.MeshWriter(args, pv_template)

# NOTE(review): presumably the number of template vertices (6890 for
# SMPL-based FAUST meshes) -- confirm; template.pos.size(0) would avoid the
# hardcoded value.
shape = 6890
# NOTE(review): presumably the per-vertex xyz coordinates -- confirm.
in_channels = 3
# For 1D conv decoder
# Hardcoded value. The commented-out expression below appears to be how it
# was derived: the largest vertex-index span (max index - min index) within
# any single face of the template.
kernel_size = 5031
# max(
#     template.face.T.max(axis=1).values - template.face.T.min(axis=1).values
# )
# Half the kernel size. With an odd kernel and stride 1 this keeps the
# sequence length unchanged (presumably the intent -- confirm).
padding = kernel_size // 2