Ejemplo n.º 1
0
def check_rotation_parity(batch: int = 10, n_atoms: int = 25):
    """Check rotation + parity equivariance of a small GatedBlockParity network.

    A GatedBlockParity built on a constant-radial convolution is applied to
    random features and geometry. Its rotated output is compared against
    the output obtained from transformed inputs. The geometry is transformed
    by ``-rot(abc)`` (an improper rotation), because positions have odd parity.

    Args:
        batch: number of samples in the random batch.
        n_atoms: number of points per sample.

    Returns:
        A 0-d boolean tensor, True when
        ``||D_out F(x, g) - F(D_in x, R g)|| < 10e-5 * ||D_out F(x, g)||``.
    """
    conv = partial(Convolution, partial(Kernel, RadialModel=ConstantRadialModel))
    Rs_in = [(1, 0, +1)]
    block = GatedBlockParity(
        Operation=conv,
        Rs_in=Rs_in,
        Rs_scalars=[(4, 0, +1)],
        act_scalars=[(-1, relu)],
        Rs_gates=[(8, 0, +1)],
        act_gates=[(-1, tanh)],
        Rs_nonscalars=[(4, 1, -1), (4, 2, +1)],
    )
    Rs_out = block.Rs_out

    # Random Euler angles shared by every representation below.
    angles = torch.randn(3)
    # Geometry has odd parity, hence the minus sign (improper rotation).
    improper_rot = -rot(*angles)
    D_in = rep(Rs_in, *angles, parity=1)
    D_out = rep(Rs_out, *angles, parity=1)

    n_channels = sum(mul * (2 * l + 1) for mul, l, _ in Rs_in)
    feat = torch.randn(batch, n_atoms, n_channels)  # transforms with D_in
    geo = torch.randn(batch, n_atoms, 3)  # transforms with improper_rot

    transform_after = torch.einsum("ij,zkj->zki", D_out, block(feat, geo))
    transform_before = block(feat @ D_in.t(), geo @ improper_rot.t())
    return (transform_after - transform_before).norm() < 10e-5 * transform_after.norm()
Ejemplo n.º 2
0
    def test3(self):
        """Test rotation equivariance on GatedBlock and dependencies."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            conv = partial(Convolution, partial(Kernel, RadialModel=ConstantRadialModel))
            f = GatedBlock(
                partial(conv, Rs_in),
                Rs_out,
                scalar_activation=sigmoid,
                gate_activation=sigmoid,
            )

            abc = torch.randn(3)
            rot_geo = rot(*abc)

            def block_rep(Rs):
                # One irr_repr block per copy of each l, stacked block-diagonally.
                return direct_sum(*[irr_repr(l, *abc) for mul, l in Rs for _ in range(mul)])

            D_in = block_rep(Rs_in)
            D_out = block_rep(Rs_out)

            fea = torch.randn(1, 4, sum(mul * (2 * l + 1) for mul, l in Rs_in))
            geo = torch.randn(1, 4, 3)

            # Transform the output vs. transform the inputs, then compare.
            out_then_rot = torch.einsum("ij,zaj->zai", (D_out, f(fea, geo)))
            rot_then_out = f(
                torch.einsum("ij,zaj->zai", (D_in, fea)),
                torch.einsum("ij,zaj->zai", rot_geo, geo),
            )
            self.assertLess((out_then_rot - rot_then_out).norm(), 10e-5 * out_then_rot.norm())
Ejemplo n.º 3
0
def check_rotation(batch: int = 10, n_atoms: int = 25):
    """Check rotation equivariance of a small GatedBlock network.

    A GatedBlock built on a constant-radial convolution (sigmoid scalar
    activation, ``absolute`` gate activation) is applied to random features
    and geometry. Rotating its output is compared against running it on
    rotated inputs.

    Args:
        batch: number of samples in the random batch.
        n_atoms: number of points per sample.

    Returns:
        A 0-d boolean tensor, True when
        ``||D_out F(x, g) - F(D_in x, R g)|| < 10e-5 * ||D_out F(x, g)||``.
    """
    conv = partial(Convolution, partial(Kernel, RadialModel=ConstantRadialModel))
    Rs_in = [(1, 0), (1, 1)]
    Rs_out = [(1, 0), (1, 1), (1, 2)]
    block = GatedBlock(
        partial(conv, Rs_in),
        Rs_out,
        scalar_activation=sigmoid,
        gate_activation=absolute,
    )

    # Geometry, input features and output features all share the same rotation.
    angles = torch.randn(3)
    R = rot(*angles)
    D_in = rep(Rs_in, *angles)
    D_out = rep(Rs_out, *angles)

    n_channels = sum(mul * (2 * l + 1) for mul, l in Rs_in)
    feat = torch.randn(batch, n_atoms, n_channels)  # transforms with D_in
    geo = torch.randn(batch, n_atoms, 3)  # transforms with R

    transform_after = torch.einsum("ij,zkj->zki", D_out, block(feat, geo))
    transform_before = block(feat @ D_in.t(), geo @ R.t())
    return (transform_after - transform_before).norm() < 10e-5 * transform_after.norm()
Ejemplo n.º 4
0
def rotation_from_orientations(t1, f1, t2, f2):
    """Build a rotation matrix relating the orientation pair (t1, f1) to (t2, f2).

    Each pair is first expressed in a frame obtained from
    ``rot(*xyz_to_angles(t), 0)``. In that frame, the remaining twist about the
    third axis is the difference of the ``atan2`` angles of the f vectors.

    Intended relation (from the original notes): ``t2 = r @ t1`` and
    ``f2 ~ r @ f1``. NOTE(review): this presumably assumes t1/t2 are unit
    vectors and f1/f2 are roughly orthogonal to them; confirm with callers.

    Returns:
        The composed rotation ``R(t2) @ rot(0, 0, c) @ R(t1).T``.
    """
    from e3nn.SO3 import xyz_to_angles, rot

    zero = t1.new_tensor(0)

    def frame_of(t):
        # Rotation built from the angles of t, with the third Euler angle zero.
        return rot(*xyz_to_angles(t), zero)

    frame_1 = frame_of(t1)
    frame_2 = frame_of(t2)

    f1_local = frame_1.t() @ f1
    f2_local = frame_2.t() @ f2

    twist = torch.atan2(f2_local[1], f2_local[0]) - torch.atan2(f1_local[1], f1_local[0])

    return frame_2 @ rot(zero, zero, twist) @ frame_1.t()
Ejemplo n.º 5
0
def rotate(x, alpha, beta, gamma):
    """Rotate a voxel field ``x`` by the Euler angles (alpha, beta, gamma).

    Spatial resampling is delegated to ``rotate_scalar`` with the matrix
    ``rot(alpha, beta, gamma)``. A 4-D ``x`` is resampled one leading-axis
    slice at a time. If ``x`` is 4-D with exactly 3 leading channels, those
    channels are additionally mixed with the l=1 ``irr_repr``. Elapsed time
    is recorded via ``time_logging`` under the key "rotate".

    Returns:
        A new tensor (built with ``x.new_tensor``, so same dtype/device as
        ``x``). The input ``x`` is not modified.
    """
    t = time_logging.start()
    # .copy(): for a CPU tensor, .cpu().detach().numpy() shares x's storage,
    # so the in-place per-slice writes below would otherwise overwrite the
    # caller's input tensor.
    y = x.detach().cpu().numpy().copy()
    R = rot(alpha, beta, gamma)
    if x.ndimension() == 4:
        for i in range(y.shape[0]):
            y[i] = rotate_scalar(y[i], R)
    else:
        y = rotate_scalar(y, R)
    x = x.new_tensor(y)
    if x.ndimension() == 4 and x.size(0) == 3:
        # 3 channels: also rotate the channel components with the l=1 representation.
        D = irr_repr(1, alpha, beta, gamma, x.dtype).to(x.device)
        x = torch.einsum("ij,jxyz->ixyz", (D, x))
    time_logging.end("rotate", t)
    return x
Ejemplo n.º 6
0
    def test2(self):
        """Test rotation equivariance on Kernel: D_out K(r) == K(R r) D_in."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(2, 0), (0, 1), (2, 2)]
            Rs_out = [(2, 0), (2, 1), (2, 2)]

            kernel = Kernel(Rs_in, Rs_out, ConstantRadialModel)
            r = torch.randn(3)

            abc = torch.randn(3)

            def block_rep(Rs):
                # One irr_repr block per copy of each l, stacked block-diagonally.
                return direct_sum(*[irr_repr(l, *abc) for mul, l in Rs for _ in range(mul)])

            D_in = block_rep(Rs_in)
            D_out = block_rep(Rs_out)

            rotated_output = D_out @ kernel(r)  # [i, j]
            rotated_input = kernel(rot(*abc) @ r) @ D_in  # [i, j]
            self.assertLess((rotated_output - rotated_input).norm(), 10e-5 * rotated_output.norm())
Ejemplo n.º 7
0
 def _get_random_affine_trafo(self):
     """Draw a random affine transform and return it as a ready-to-call resampler.

     The 3x3 linear part combines, depending on instance flags:
       * ``self.rotate``: a random rotation ``rot(alpha, beta, gamma)``;
       * ``self.flip``: a random sign (+1/-1) applied to the first column;
       * ``self.scale``: per-column factors drawn uniformly from
         ``[self.scale[0], self.scale[1])`` (no scaling when ``None``).
     The offset makes the transform act around ``self.vol_shape / 2``. With
     ``self.translate``, it is jittered by U(-0.5, 0.5) per axis.

     Returns:
         ``partial(affine_transform, matrix=aff, offset=offset)``. Call it with
         the volume (plus any extra ``affine_transform`` keyword arguments).
     """
     if self.rotate:
         # Uniform sampling of SO(3) with ZYZ Euler angles requires
         # cos(beta) ~ U[-1, 1]. A uniform beta over-samples rotations near the
         # poles. This is the same recipe used in get_volumes.
         alpha, gamma = 2 * np.pi * np.random.rand(2)
         beta = np.arccos(2 * np.random.rand() - 1)
         # Fresh float ndarray: the columns are scaled in place below.
         # NOTE(review): np.array also accepts a (CPU) torch tensor in case rot returns one.
         aff = np.array(rot(alpha, beta, gamma), dtype=np.float64)
     else:
         aff = np.eye(3)  # only non-homogeneous coord part
     fl = (-1) ** np.random.randint(low=0, high=2) if self.flip else 1
     if self.scale is not None:
         sx, sy, sz = np.random.uniform(low=self.scale[0], high=self.scale[1], size=3)
     else:
         # `sx, sy, sz = 1` raised TypeError (an int cannot be unpacked).
         sx = sy = sz = 1
     aff[:, 0] *= sx * fl
     aff[:, 1] *= sy
     aff[:, 2] *= sz
     # asarray lets vol_shape be a tuple as well as an ndarray.
     # NOTE(review): the centre of indices 0..n-1 is (n-1)/2; n/2 is kept as before.
     center = np.asarray(self.vol_shape, dtype=np.float64) / 2
     # affine_transform samples the input at aff @ o + offset for output index o,
     # so offset = c - aff @ c keeps the centre fixed.
     offset = center - aff @ center
     if self.translate:
         offset += np.random.uniform(low=-.5, high=.5, size=3)
     return partial(affine_transform, matrix=aff, offset=offset)
Ejemplo n.º 8
0
    def test1(self):
        """Test rotation equivariance of GatedBlock (Softplus scalars, sigmoid gates)."""
        with torch_default_dtype(torch.float64):
            Rs_in = [(3, 0), (3, 1), (2, 0), (1, 2)]
            Rs_out = [(3, 0), (3, 1), (1, 2), (3, 0)]

            conv = partial(Convolution, partial(Kernel, RadialModel=ConstantRadialModel))
            f = GatedBlock(
                partial(conv, Rs_in),
                Rs_out,
                rescaled_act.Softplus(beta=5),
                rescaled_act.sigmoid,
            )

            abc = torch.randn(3)

            def block_rep(Rs):
                # One irr_repr block per copy of each l, stacked block-diagonally.
                return direct_sum(*[irr_repr(l, *abc) for mul, l in Rs for _ in range(mul)])

            D_in = block_rep(Rs_in)
            D_out = block_rep(Rs_out)

            n_in = sum(mul * (2 * l + 1) for mul, l in Rs_in)
            x = torch.randn(1, 5, n_in)
            geo = torch.randn(1, 5, 3)

            x_rot = torch.einsum("ij,zaj->zai", (D_in, x))
            geo_rot = geo @ rot(*abc).t()

            out = f(x, geo, dim=2)
            out_rot = torch.einsum("ij,zaj->zai", (D_out, out))

            self.assertLess((f(x_rot, geo_rot) - out_rot).norm(), 1e-10 * out_rot.norm())
Ejemplo n.º 9
0
    def test1(self):
        """Test irr_repr and clebsch_gordan equivariance."""
        with torch_default_dtype(torch.float64):
            l_in, l_out = 3, 2

            for l_f in range(abs(l_in - l_out), l_in + l_out + 1):
                r = torch.randn(100, 3)
                Q = clebsch_gordan(l_out, l_in, l_f)

                abc = torch.randn(3)
                D_in = irr_repr(l_in, *abc)
                D_out = irr_repr(l_out, *abc)

                def cg_kernel(points):
                    # Contract the Clebsch-Gordan tensor with order-l_f harmonics.
                    Y = spherical_harmonics_xyz(l_f, points)
                    return torch.einsum("ijk,kz->zij", (Q, Y))

                W_of_rotated = cg_kernel(r @ rot(*abc).t())
                lhs = torch.einsum("zij,jk->zik", (W_of_rotated, D_in))

                W = cg_kernel(r)
                rhs = torch.einsum("ij,zjk->zik", (D_out, W))

                self.assertLess((lhs - rhs).norm(), 1e-5 * W.norm(), l_f)
Ejemplo n.º 10
0
    def test6(self):
        """Test parity and rotation equivariance on GatedBlockParity and dependencies."""
        with torch_default_dtype(torch.float64):
            mul = 2
            Rs_in = [(mul, l, p) for l in range(6) for p in [-1, 1]]

            conv = partial(Convolution, partial(Kernel, RadialModel=ConstantRadialModel))

            Rs_scalars = [(mul, 0, +1), (mul, 0, -1)]
            act_scalars = [(mul, relu), (mul, absolute)]
            Rs_nonscalars = [(mul, 1, +1), (mul, 1, -1), (mul, 2, +1),
                             (mul, 2, -1), (mul, 3, +1), (mul, 3, -1)]
            n_gates = 3 * mul
            Rs_gates = [(n_gates, 0, +1), (n_gates, 0, -1)]
            act_gates = [(n_gates, sigmoid), (n_gates, tanh)]

            f = GatedBlockParity(conv, Rs_in, Rs_scalars, act_scalars,
                                 Rs_gates, act_gates, Rs_nonscalars)

            abc = torch.randn(3)
            # Geometry has odd parity: use the improper rotation -R.
            rot_geo = -rot(*abc)

            def parity_rep(Rs):
                # Parity-signed irr_repr block per copy of each (l, p), block-diagonal.
                return direct_sum(*[p * irr_repr(l, *abc) for m, l, p in Rs for _ in range(m)])

            D_in = parity_rep(Rs_in)
            D_out = parity_rep(f.Rs_out)

            fea = torch.randn(1, 4, sum(m * (2 * l + 1) for m, l, _ in Rs_in))
            geo = torch.randn(1, 4, 3)

            out_then_rot = torch.einsum("ij,zaj->zai", (D_out, f(fea, geo)))
            rot_then_out = f(
                torch.einsum("ij,zaj->zai", (D_in, fea)),
                torch.einsum("ij,zaj->zai", rot_geo, geo),
            )
            self.assertLess((out_then_rot - rot_then_out).norm(), 10e-5 * out_then_rot.norm())
Ejemplo n.º 11
0
def get_volumes(size=20, pad=8, rotate=False, rotate90=False):
    """Voxelise the eight 3D tetris pieces, one volume per class.

    Each piece is a set of cells in a 4x4x4 grid. The grid is upsampled by
    ``size / 4`` with nearest-neighbour ``zoom`` (order 0), then zero-padded by
    ``pad`` on every side. Optionally it is resampled with a random rotation
    (``rotate``: alpha, gamma ~ U[0, 2pi), beta = arccos(U[-1, 1])) and/or
    passed through ``rot_volume_90`` (``rotate90``).

    Args:
        size: edge length after upsampling; must be >= 4 (asserted).
        pad: zero padding added on each side of every axis.
        rotate: apply a random continuous rotation via ``rotate_scalar``.
        rotate90: apply ``rot_volume_90``.

    Returns:
        ``(volumes, labels)``. ``volumes`` is a float32 array stacking one
        ``(1, ...)`` volume per piece; without the optional rotations each
        volume is ``(1, size + 2*pad, size + 2*pad, size + 2*pad)``.
        ``labels`` is ``np.arange(8)``.
    """
    assert size >= 4
    pieces = [
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
        [(0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 0, 0)],  # chiral_shape_2
        [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
        [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],  # zigzag
    ]
    labels = np.arange(len(pieces))

    def occupancy(cells):
        # 4x4x4 grid with a 1 in every listed cell.
        grid = np.zeros((4, 4, 4))
        for cell in cells:
            grid[cell] = 1
        return grid

    volumes = []
    for cells in pieces:
        vol = np.pad(zoom(occupancy(cells), size / 4, order=0), pad, 'constant')

        if rotate:
            a, c = np.random.rand(2) * 2 * np.pi
            b = np.arccos(np.random.rand() * 2 - 1)
            vol = rotate_scalar(vol, rot(a, b, c))
        if rotate90:
            vol = rot_volume_90(vol)

        volumes.append(vol[np.newaxis, ...])

    return np.stack(volumes).astype(np.float32), labels
Ejemplo n.º 12
0
def train_step(item, top, front, model, optimizer):
    """Run one optimisation step and return scalar training metrics.

    ``top`` and ``front`` are rotated by a random rotation (each Euler angle
    uniform in [-15, 15] degrees). They are concatenated along dim 1 and
    broadcast over dims 2-4 of ``item`` as extra input channels (the "tip").
    The model output is split into ``pred_top`` (channels ``:3``) and
    ``pred_front`` (channels ``3:``).

    The loss rewards ``overlap`` with the unrotated ``top``/``front`` and
    penalises ``|overlap(pred_top, pred_front)|``.
    NOTE(review): assumes ``top``/``front`` are (batch, 3) and ``item`` is
    5-D with channels on dim 1; the code's view/expand/cat calls imply this, but confirm.

    Returns:
        Tuple of Python floats: mean overlap_top, mean overlap_front, mean
        |overlap_orth|, and the mean ``angle_from_rotation`` of the rotation
        mapping (top, front) onto (pred_top, pred_front).
    """
    from e3nn.SO3 import rot

    model.train()

    # Euler angles uniform in [-15, 15] degrees, converted to radians.
    abc = 15 * (torch.rand(3) * 2 - 1) / 180 * math.pi
    r = rot(*abc).to(item.device)
    tip = torch.cat(
        [torch.einsum("ij,...j->...i", (r, v)) for v in [top, front]], dim=1)
    # Broadcast the per-sample hint over the spatial grid of `item`.
    tip = tip.view(tip.size(0), tip.size(1), 1, 1, 1).expand(
        tip.size(0), tip.size(1), item.size(2), item.size(3), item.size(4))

    model_input = torch.cat([item, tip], dim=1)

    prediction = model(model_input)
    pred_top, pred_front = prediction[:, :3], prediction[:, 3:]

    overlap_top = overlap(pred_top, top)
    overlap_front = overlap(pred_front, front)
    overlap_orth = overlap(pred_top, pred_front)

    optimizer.zero_grad()
    (-overlap_top - overlap_front + overlap_orth.abs()).mean().backward()
    optimizer.step()

    # Metric only: do not record an autograd graph on the (already
    # back-propagated) predictions.
    with torch.no_grad():
        a = torch.mean(
            torch.tensor([
                angle_from_rotation(
                    rotation_from_orientations(top[i], front[i], pred_top[i],
                                               pred_front[i]))
                for i in range(len(item))
            ]))

    return (overlap_top.mean().item(), overlap_front.mean().item(),
            overlap_orth.abs().mean().item(), a.item())