Beispiel #1
0
def get_dataset():
    """Return the eight 3D tetris shapes, all rotated together, plus labels.

    Returns:
        A pair ``(tetris, labels)``:

        * ``tetris`` -- tensor of shape [8, 4, 3] in the default dtype: eight
          shapes of four integer grid points each, all rotated by the *same*
          matrix drawn once from ``o3.rand_rot()``.
        * ``labels`` -- tensor of shape [8, 7] in the default dtype. Column 0
          holds +1 / -1 for the two chiral shapes (mirror images of each
          other under y -> -y) and 0 elsewhere; columns 1..6 one-hot encode
          the remaining six shapes in order.
    """
    dtype = torch.get_default_dtype()

    points = torch.tensor(
        [
            [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
            [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],  # chiral_shape_2
            [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
            [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
            [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
            [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L
            [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T
            [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],  # zigzag
        ],
        dtype=dtype,
    )

    labels = torch.zeros(len(points), 7, dtype=dtype)
    # The chiral pair shares column 0 and is told apart only by its sign ...
    labels[0, 0] = 1
    labels[1, 0] = -1
    # ... every other shape gets its own column: square -> 1, ..., zigzag -> 6.
    labels[2:, 1:] = torch.eye(6, dtype=dtype)

    # A single random rotation R applied to every point of every shape
    # (p' = R p); presumably o3.rand_rot() returns a 3x3 rotation matrix.
    rotation = o3.rand_rot()
    rotated = torch.einsum('ij,zaj->zai', rotation, points)

    return rotated, labels
Beispiel #2
0
 def test_rot_to_abc(self):
     """Round-trip a random rotation through ``o3.rot_to_abc`` and ``o3.rot``.

     Inside ``o3.torch_default_dtype(torch.float64)``, the matrix rebuilt
     from the extracted angles must match the original with a relative
     (Frobenius) norm error below 1e-10.
     """
     with o3.torch_default_dtype(torch.float64):
         original = o3.rand_rot()
         diff = original - o3.rot(*o3.rot_to_abc(original))
         rel_err = diff.norm() / original.norm()
         self.assertTrue(rel_err < 1e-10, rel_err)
Beispiel #3
0
    def __init__(self, Rs, act, n):
        """Map features to a signal on SO(3), apply ``act`` pointwise, project back.

        The input is read as the coefficients of a signal on SO(3) in the
        regular representation (up to an overall multiplicity ``mul0``), so a
        pointwise operation can be applied to samples of that signal.

        :param Rs: input representation; after ``rs.simplify`` it must be a
            list of ``(mul, l, p)`` entries whose ``l`` run 0, 1, ...,
            ``len(Rs) - 1`` in order, with ``mul == mul0 * (2 * l + 1)``
            (``mul0`` being the multiplicity of ``l = 0``) and ``p == 0``
        :param act: activation function applied pointwise to the samples
        :param n: number of random rotations (sample points on SO(3)) at
            which the signal is evaluated; the higher, the more accurate
        :raises AssertionError: if ``Rs`` violates the constraints above
            (checked with ``assert``, so skipped under ``python -O``)
        """
        super().__init__()

        Rs = rs.simplify(Rs)
        mul0, _, _ = Rs[0]
        assert all(mul0 * (2 * l + 1) == mul for mul, l, _ in Rs), \
            "each l must have multiplicity mul0 * (2l + 1)"
        assert [l for _, l, _ in Rs] == list(range(len(Rs))), \
            "Rs must contain l = 0, 1, ..., L exactly once and in order"
        assert all(p == 0 for _, _, p in Rs), "all parities in Rs must be 0"

        self.Rs_out = Rs

        # Sample points on SO(3): n random rotations. Their angles are
        # computed once per rotation, not once per l as the inner loop would.
        rotations = [o3.rand_rot() for _ in range(n)]
        angles = [o3.rot_to_abc(R) for R in rotations]

        # Row z: the matrices o3.irr_repr(l, ...) evaluated at sample z for
        # l = 0..L, each flattened and scaled by sqrt(2l + 1), then
        # concatenated over l.
        Z = torch.stack([
            torch.cat([
                o3.irr_repr(l, *abc).flatten() * (2 * l + 1)**0.5
                for l in range(len(Rs))
            ])
            for abc in angles
        ])  # [z, lmn]
        # Global rescale by 1/sqrt(number of columns).
        # NOTE(review): presumably keeps the to/from-signal maps roughly
        # norm-preserving -- confirm against the forward pass.
        Z.div_(Z.shape[1]**0.5)
        self.register_buffer('Z', Z)
        self.act = act
Beispiel #4
0
def get_dataset():
    """Return the eight 3D tetris shapes, each independently rotated, plus labels.

    Returns:
        A pair ``(tetris, labels)``:

        * ``tetris`` -- tensor of shape [8, 4, 3] in the default dtype: eight
          shapes of four 3D grid points each, where every shape is rotated
          by its own sample from ``o3.rand_rot()``.
        * ``labels`` -- ``torch.arange(8)``: one integer class per shape
          (the two chiral shapes are separate classes 0 and 1).
    """
    tetris = [[(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
              [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],  # chiral_shape_2
              [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
              [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
              [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
              [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L
              [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T
              [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)]]  # zigzag
    tetris = torch.tensor(tetris, dtype=torch.get_default_dtype())
    labels = torch.arange(len(tetris))

    # Rotate each shape by its own random matrix R (p' = R p for each point).
    # Operands are passed positionally rather than as the legacy single tuple.
    tetris = torch.stack([
        torch.einsum("ij,nj->ni", o3.rand_rot(), x)
        for x in tetris
    ])

    return tetris, labels