Exemplo n.º 1
0
def check_network_equivariance(network, Rs_in, Rs_out, vector_output: bool = False, batch: int = 2,
                               n_atoms: int = 6) -> bool:
    """
    Check that ``network`` is equivariant under one random rotation.

    The network is called as ``network(feat, geo)`` with ``feat`` of shape
    ``[batch, n_atoms, dim(Rs_in)]`` and ``geo`` of shape ``[batch, n_atoms, 3]``.
    Rotating the inputs first must give the same result as rotating the output
    afterwards:

    - features are transformed with the matrix ``rep(Rs_in, *abc)``;
    - geometry is transformed with ``rot(*abc)``;
    - the output is transformed with ``rep(Rs_out, *abc)``, or with the geometry
      matrix when ``vector_output`` is True.

    If any irrep in ``Rs_in`` or ``Rs_out`` has non-zero parity, the geometry
    matrix is negated (``-R``), so the improper rotation is tested.

    :param network: Often the forward pass of an object inheriting from torch.nn.Module.
        When it exposes ``parameters()``, they must all be on one device, which is
        then used for the test tensors; otherwise the CPU is used.
    :param Rs_in: Rs for the input of the network.
    :param Rs_out: Rs for the output of the network.
    :param vector_output: When True, the network outputs vectors which have already been converted to xyz basis.
    :param batch: Batch size of test case.
    :param n_atoms: Number of atoms of test case.
    :return: True when ``||RF - FR|| < 1e-4 * ||RF||``, where RF rotates the output
        and FR feeds rotated inputs.
    :raises AssertionError: if the network's parameters live on more than one device.
    """

    def check_parity(Rs):
        """Return True if any irrep in ``Rs`` has non-zero parity."""
        return any(p != 0 for _, _, p in normalizeRs(Rs))

    # All parameters must share one device; build the test tensors there.
    # A parameter-free network (or a plain callable) falls back to the CPU
    # instead of failing with IndexError / AttributeError.
    params = network.parameters() if hasattr(network, "parameters") else ()
    devices = [param.device for param in params]
    device = devices[0] if devices else torch.device("cpu")
    assert all(d == device for d in devices), "network parameters are spread over several devices"

    Rs_in = normalizeRs(Rs_in)
    Rs_out = normalizeRs(Rs_out)
    parity = check_parity(Rs_in) or check_parity(Rs_out)

    abc = torch.randn(3, device=device)  # Random Euler angles.
    D_in = rep(Rs_in, *abc)
    # Geometry is treated as odd parity: use the improper rotation -R when parity is tested.
    # NOTE(review): D_in / D_out are built without a parity argument while geo uses -R;
    # confirm rep() accounts for parity, otherwise parity-odd Rs will fail this check.
    geo_rotation_matrix = -rot(*abc) if parity else rot(*abc)
    D_out = rep(Rs_out, *abc)

    c = sum(mul * (2 * l + 1) for mul, l, _ in Rs_in)
    feat = torch.randn(batch, n_atoms, c, device=device)  # Transforms with D_in.
    geo = torch.randn(batch, n_atoms, 3, device=device)  # Transforms with the rotation matrix.

    # F is 3-D (z, k, j) per the einsum below; j indexes the output components.
    F = network(feat, geo)
    out_matrix = geo_rotation_matrix if vector_output else D_out
    RF = torch.einsum("ij,zkj->zki", out_matrix, F)  # Rotate after the network.
    FR = network(feat @ D_in.t(), geo @ geo_rotation_matrix.t())  # Rotate before the network.
    return bool((RF - FR).norm() < 1e-4 * RF.norm())
Exemplo n.º 2
0
def rotate(x, alpha):
    """Rotate ``x`` by the single Euler angle ``alpha`` (the other two angles are 0).

    The data is copied out to a CPU NumPy array and rotated with ``rotate_scalar``
    using the matrix ``rot(alpha, 0, 0)``. It is then wrapped back into a tensor via
    ``x.new_tensor``. The elapsed time is recorded under "rotate" with ``time_logging``.
    """
    started = time_logging.start()
    array = x.cpu().detach().numpy()
    rotated = rotate_scalar(array, rot(alpha, 0, 0))
    result = x.new_tensor(rotated)
    time_logging.end("rotate", started)
    return result
Exemplo n.º 3
0
def rotation_from_orientations(t1, f1, t2, f2):
    """Return a rotation matrix ``r`` with ``t2 = r @ t1`` and ``f2 ~ r @ f1``.

    Each ``t`` is mapped from a reference frame by ``rot(*x_to_alpha_beta(t), 0)``.
    The ``f`` vectors are expressed in those frames, and the difference of their
    in-plane angles (``atan2`` of components 1 and 0) gives the remaining twist
    ``rot(0, 0, twist)``. The result composes frame(t2) @ twist @ frame(t1)^T.
    """
    from se3cnn.SO3 import x_to_alpha_beta, rot

    zero = t1.new_tensor(0)

    def frame_of(direction):
        return rot(*x_to_alpha_beta(direction), zero)

    frame_1 = frame_of(t1)
    frame_2 = frame_of(t2)

    f1_local = frame_1.t() @ f1
    f2_local = frame_2.t() @ f2

    angle_1 = torch.atan2(f1_local[1], f1_local[0])
    angle_2 = torch.atan2(f2_local[1], f2_local[0])
    twist = rot(zero, zero, angle_2 - angle_1)

    return frame_2 @ twist @ frame_1.t()
Exemplo n.º 4
0
def rotate(x, alpha, beta, gamma):
    """Rotate the field(s) in ``x`` by the Euler angles (alpha, beta, gamma).

    ``x`` is copied to a CPU NumPy array and rotated with ``rotate_scalar`` using
    ``rot(alpha, beta, gamma)``.

    - For a 4-D tensor, every slice along the first axis is rotated separately.
    - If that first axis additionally has size 3, the three channels are mixed with
      ``irr_repr(1, ...)``. This presumably means they are the components of a
      vector field; confirm with callers.

    The input tensor is never modified. The result is a new tensor with the dtype
    and device of ``x``. The elapsed time is recorded under "rotate" with
    ``time_logging``.
    """
    t = time_logging.start()
    # ``.numpy()`` shares memory with a CPU tensor; copy so the per-slice writes
    # below cannot overwrite the caller's ``x``.
    y = x.cpu().detach().numpy().copy()
    R = rot(alpha, beta, gamma)
    if x.ndimension() == 4:
        for i in range(y.shape[0]):
            y[i] = rotate_scalar(y[i], R)
    else:
        y = rotate_scalar(y, R)
    x = x.new_tensor(y)
    if x.ndimension() == 4 and x.size(0) == 3:
        # Transform the channel values themselves with the l=1 representation.
        d1 = irr_repr(1, alpha, beta, gamma, x.dtype).to(x.device)
        x = torch.einsum("ij,jxyz->ixyz", (d1, x))
    time_logging.end("rotate", t)
    return x
Exemplo n.º 5
0
 def _get_random_affine_trafo(self):
     """Draw a random affine transform and return it as a ready-to-call resampler.

     The 3x3 linear part combines:

     - a random rotation (if ``self.rotate``);
     - a mirror of the first axis with probability 1/2 (if ``self.flip``);
     - per-axis scale factors drawn uniformly from ``[self.scale[0], self.scale[1])``
       (if ``self.scale`` is not None).

     The offset keeps the volume center fixed. If ``self.translate`` is set, a
     uniform shift in [-0.5, 0.5) per axis is added.

     :return: ``partial(affine_transform, matrix=aff, offset=offset)``; call it
         with the volume to resample.
     """
     if self.rotate:
         # Uniform rotation: alpha, gamma uniform in [0, 2pi) and cos(beta)
         # uniform in (-1, 1]. Drawing beta itself uniformly over-represents
         # rotations near beta = 0 and pi.
         # NOTE(review): assumes rot() takes ZYZ Euler angles with beta as the polar angle.
         u = np.random.rand(3)
         alpha = 2 * np.pi * u[0]
         beta = np.arccos(1 - 2 * u[1])
         gamma = 2 * np.pi * u[2]
         aff = rot(alpha, beta, gamma)
     else:
         aff = np.eye(3)  # only non-homogeneous coord part
     fl = (-1) ** np.random.randint(low=0, high=2) if self.flip else 1
     if self.scale is not None:
         sx, sy, sz = np.random.uniform(low=self.scale[0],
                                        high=self.scale[1],
                                        size=3)
     else:
         sx = sy = sz = 1  # no scaling (an int cannot be unpacked into three names)
     aff[:, 0] *= sx * fl
     aff[:, 1] *= sy
     aff[:, 2] *= sz
     # asarray lets vol_shape be a plain tuple as well as an array.
     center = np.asarray(self.vol_shape) / 2
     offset = center - center @ aff.T  # correct offset to apply trafo around center
     if self.translate:
         offset += np.random.uniform(low=-.5, high=.5, size=3)
     return partial(affine_transform, matrix=aff, offset=offset)
Exemplo n.º 6
0
def get_volumes(size=20, pad=8, rotate=False, rotate90=False):
    """Voxelize the eight tetris shapes.

    Each shape is four occupied cells of a 4x4x4 grid. The grid is upsampled by
    ``size / 4`` with nearest-neighbour ``zoom`` and zero-padded by ``pad`` on
    every side. It is then optionally rotated by a random rotation (``rotate``)
    and/or passed through ``rot_volume_90`` (``rotate90``).

    :return: ``(volumes, labels)``, where ``volumes`` is a float32 array with one
        volume per shape and a singleton channel axis inserted after the shape
        axis, and ``labels`` is ``np.arange`` over the shapes.
    """
    assert size >= 4
    shapes = [
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
        [(0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 0, 0)],  # chiral_shape_2
        [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
        [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],  # zigzag
    ]
    labels = np.arange(len(shapes))

    def voxelize(cells):
        grid = np.zeros((4, 4, 4))
        for cell in cells:
            grid[cell] = 1
        grid = np.pad(zoom(grid, size / 4, order=0), pad, 'constant')

        if rotate:
            # alpha, gamma uniform in [0, 2pi); cos(beta) uniform -> uniform rotation
            alpha, gamma = np.random.rand(2) * 2 * np.pi
            beta = np.arccos(np.random.rand() * 2 - 1)
            grid = rotate_scalar(grid, rot(alpha, beta, gamma))
        if rotate90:
            grid = rot_volume_90(grid)

        return grid[np.newaxis, ...]

    volumes = np.stack([voxelize(cells) for cells in shapes]).astype(np.float32)
    return volumes, labels
Exemplo n.º 7
0
def train_step(item, top, front, model, optimizer):
    """Run one optimization step and return training metrics.

    Steps:

    1. Draw a random rotation whose three Euler angles are each uniform in
       [-15, 15] degrees.
    2. Apply it to ``top`` and ``front``, concatenate the results along dim 1, and
       broadcast them over dims 2-4 of ``item`` as extra input channels. ``item``
       is indexed up to dim 4, so it is at least 5-D.
    3. Channels ``[:3]`` of the prediction are compared with ``top`` and the rest
       with ``front``.
    4. The loss is ``-overlap_top - overlap_front + |overlap_orth|``, averaged.

    :return: (mean overlap_top, mean overlap_front, mean |overlap_orth|, mean angle
        of the rotation between target and predicted orientations), as Python floats.
    """
    from se3cnn.SO3 import rot

    model.train()

    abc = 15 * (torch.rand(3) * 2 - 1) / 180 * math.pi
    r = rot(*abc).to(item.device)
    tip = torch.cat(
        [torch.einsum("ij,...j->...i", (r, x)) for x in [top, front]], dim=1)
    tip = tip.view(tip.size(0), tip.size(1), 1, 1, 1).expand(
        tip.size(0), tip.size(1), item.size(2), item.size(3), item.size(4))

    input = torch.cat([item, tip], dim=1)

    prediction = model(input)
    pred_top, pred_front = prediction[:, :3], prediction[:, 3:]

    overlap_top = overlap(pred_top, top)
    overlap_front = overlap(pred_front, front)
    overlap_orth = overlap(pred_top, pred_front)

    optimizer.zero_grad()
    (-overlap_top - overlap_front + overlap_orth.abs()).mean().backward()
    optimizer.step()

    # The angle is only reported, never differentiated: compute it without
    # autograd so no graph is built on the (graph-attached) predictions.
    with torch.no_grad():
        a = torch.mean(
            torch.tensor([
                angle_from_rotation(
                    rotation_from_orientations(top[i], front[i], pred_top[i],
                                               pred_front[i]))
                for i in range(len(item))
            ]))

    return (overlap_top.mean().item(), overlap_front.mean().item(),
            overlap_orth.abs().mean().item(), a.item())