Example #1
import torch
from e3nn import o3
from torch_geometric.data import Data, DataLoader  # DataLoader lives in torch_geometric.loader in newer PyG releases


def tetris():
    pos = [
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],  # chiral_shape_2
        [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
        [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],  # zigzag
    ]
    pos = torch.tensor(pos, dtype=torch.get_default_dtype())

    # Since the chiral shapes are mirror images of one another, we need an *odd* scalar to distinguish them
    # (see the parity check sketched after this function)
    labels = torch.tensor(
        [
            [+1, 0, 0, 0, 0, 0, 0],  # chiral_shape_1
            [-1, 0, 0, 0, 0, 0, 0],  # chiral_shape_2
            [0, 1, 0, 0, 0, 0, 0],  # square
            [0, 0, 1, 0, 0, 0, 0],  # line
            [0, 0, 0, 1, 0, 0, 0],  # corner
            [0, 0, 0, 0, 1, 0, 0],  # L
            [0, 0, 0, 0, 0, 1, 0],  # T
            [0, 0, 0, 0, 0, 0, 1],  # zigzag
        ],
        dtype=torch.get_default_dtype())

    # apply random rotation
    pos = torch.einsum('zij,zaj->zai', o3.rand_matrix(len(pos)), pos)

    # put in torch_geometric format
    dataset = [Data(pos=p, x=torch.ones(4, 1)) for p in pos]
    data = next(iter(DataLoader(dataset, batch_size=len(dataset))))

    return data, labels
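
The parity comment in tetris() is the key design point: under a mirror (an improper rotation with determinant -1), an even scalar is unchanged while an odd scalar flips sign, which is what lets a model separate the two chiral pieces. A minimal check of that fact, assuming e3nn's Irreps.D_from_matrix (not used in the snippet above):

import torch
from e3nn import o3

mirror = -torch.eye(3)  # inversion: an orthogonal matrix with determinant -1
print(o3.Irreps("0e").D_from_matrix(mirror))  # tensor([[ 1.]])
print(o3.Irreps("0o").D_from_matrix(mirror))  # tensor([[-1.]])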
Example #2
import torch
from e3nn import o3


def test_xyz(float_tolerance):
    # float_tolerance is provided as a pytest fixture
    R = o3.rand_matrix(10)
    # random rotation matrices should be orthogonal
    assert (R @ R.transpose(-1, -2) -
            torch.eye(3)).abs().max() < float_tolerance

    # (a, b) are the spherical angles of the image of the y axis under R
    a, b, c = o3.matrix_to_angles(R)
    pos1 = o3.angles_to_xyz(a, b)
    pos2 = R @ torch.tensor([0, 1.0, 0])
    assert torch.allclose(pos1, pos2, atol=float_tolerance)

    # converting the point back recovers the first two angles
    a2, b2 = o3.xyz_to_angles(pos2)
    assert (a - a2).abs().max() < float_tolerance
    assert (b - b2).abs().max() < float_tolerance
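
The convention exercised by this test is that the angle pair (a, b) describes the direction R maps the y axis to, so angles_to_xyz(a, b) must coincide with R @ [0, 1, 0]. A quick sanity check of that convention (a sketch under the same assumptions as the test above):

import torch
from e3nn import o3

# with both angles zero, the parameterized point sits on the y axis
print(o3.angles_to_xyz(torch.tensor(0.0), torch.tensor(0.0)))  # approximately [0., 1., 0.]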
Example #3
from e3nn import o3


def test_conversions(float_tolerance):
    # wrap converters whose input is a tuple (e.g. angles (a, b, c)) so that
    # every entry of the conversion table below has the same calling convention
    def wrap(f):
        def g(x):
            if isinstance(x, tuple):
                return f(*x)
            else:
                return f(x)

        return g

    def identity(x):
        return x

    # conv[i][j] converts representation i to representation j,
    # with 0 = angles, 1 = matrix, 2 = axis-angle, 3 = quaternion
    conv = [
        [
            identity,
            wrap(o3.angles_to_matrix),
            wrap(o3.angles_to_axis_angle),
            wrap(o3.angles_to_quaternion)
        ],
        [
            wrap(o3.matrix_to_angles), identity,
            wrap(o3.matrix_to_axis_angle),
            wrap(o3.matrix_to_quaternion)
        ],
        [
            wrap(o3.axis_angle_to_angles),
            wrap(o3.axis_angle_to_matrix), identity,
            wrap(o3.axis_angle_to_quaternion)
        ],
        [
            wrap(o3.quaternion_to_angles),
            wrap(o3.quaternion_to_matrix),
            wrap(o3.quaternion_to_axis_angle), identity
        ],
    ]

    R1 = o3.rand_matrix(100)
    # starting and ending at the matrix representation (index 1), this path
    # exercises each of the twelve pairwise conversions exactly once
    path = [1, 2, 3, 0, 2, 0, 3, 1, 3, 2, 1, 0, 1]

    g = R1
    for i, j in zip(path, path[1:]):
        g = conv[i][j](g)
    R2 = g

    assert (R1 - R2).abs().median() < float_tolerance
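
For comparison, a single round trip through one pair of converters looks like this (a sketch using the same o3 helpers as the test above):

from e3nn import o3

a, b, c = o3.rand_angles(10)
R = o3.angles_to_matrix(a, b, c)
a2, b2, c2 = o3.matrix_to_angles(R)
# the recovered angles may differ from (a, b, c), but they describe the same rotations
assert (o3.angles_to_matrix(a2, b2, c2) - R).abs().max() < 1e-5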
Example #4
import torch
from e3nn import o3
from torch_geometric.data import Data, DataLoader


def test():
    torch.set_default_dtype(torch.float64)

    pos = torch.tensor([
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.5],
    ])

    # atom type
    z = torch.tensor([0, 1, 2, 2])

    # ten randomly rotated copies of the same structure
    dataset = [Data(pos=pos @ R.T, z=z) for R in o3.rand_matrix(10)]
    data = next(iter(DataLoader(dataset, batch_size=len(dataset))))

    # InvariantPolynomial is defined in the full example this snippet is taken from
    f = InvariantPolynomial("0e+0o", num_z=3, lmax=3)

    out = f(data)

    # expect invariant output: every rotated copy should give the same result
    assert out.std(0).max() < 1e-5
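
The final assertion relies on a simple pattern: evaluate the model on several rotated copies of the same structure and check that the outputs agree. The same pattern applied to a plainly rotation-invariant function, independent of InvariantPolynomial (a sketch):

import torch
from e3nn import o3

pos = torch.randn(4, 3)
# the sum of point norms does not change under rotation
outs = torch.stack([(pos @ R.T).norm(dim=-1).sum() for R in o3.rand_matrix(10)])
assert outs.std() < 1e-5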
Example #5
import itertools

import torch

from e3nn import o3


def equivariance_error(func,
                       args_in,
                       irreps_in=None,
                       irreps_out=None,
                       ntrials=1,
                       do_parity=True,
                       do_translation=True):
    r"""Get the maximum equivariance error for ``func`` over ``ntrials``

    Each trial randomizes the equivariant transformation tested.

    Parameters
    ----------
    func : callable
        the function to test
    args_in : list
        the original inputs to pass to ``func``.
    irreps_in : list of `e3nn.o3.Irreps` or `e3nn.o3.Irreps`
        the input irreps for each of the arguments in ``args_in``. If left as the
        default of ``None``, ``get_io_irreps`` will be used to try to infer them.
        If a sequence is provided, valid elements are also the string
        ``'cartesian_points'``, which denotes that the corresponding input should
        be treated as Cartesian points in 3D, and ``None``, which indicates that
        the argument should not be transformed.
    irreps_out : list of `e3nn.o3.Irreps` or `e3nn.o3.Irreps`
        the output irreps for each of the return values of ``func``. Accepts the
        same values as ``irreps_in``.
    ntrials : int
        run this many trials with random transforms
    do_parity : bool
        whether to test parity
    do_translation : bool
        whether to test translation for ``'cartesian_points'`` inputs

    Returns
    -------
    dictionary mapping tuples ``(parity_k, did_translate)`` to errors
    """
    # _get_io_irreps and _transform are private helpers defined alongside this
    # function; they infer missing irreps and apply the group action to arguments
    irreps_in, irreps_out = _get_io_irreps(func,
                                           irreps_in=irreps_in,
                                           irreps_out=irreps_out)

    if do_parity:
        parity_ks = [0, 1]
    else:
        parity_ks = [0]

    if 'cartesian_points' not in irreps_in:
        # There's nothing to translate
        do_translation = False
    if do_translation:
        translations = [False, True]
    else:
        translations = [False]

    # materialize the list so that every trial iterates over all test cases
    tests = list(itertools.product(parity_ks, translations))

    neg_inf = -float("Inf")
    biggest_errs = {}

    for trial in range(ntrials):
        for this_test in tests:
            parity_k, this_do_translate = this_test
            # Build a rotation matrix for point data
            rot_mat = o3.rand_matrix()
            # add parity
            rot_mat *= (-1)**parity_k
            # build translation
            if this_do_translate:
                translation = 10 * torch.randn(1, 3, dtype=rot_mat.dtype)
            else:
                translation = 0.0

            # Evaluate the function on rotated arguments:
            rot_args = _transform(args_in, irreps_in, rot_mat, translation)
            x1 = func(*rot_args)

            # Evaluate the function on the arguments, then apply group action:
            x2 = func(*args_in)

            # Deal with output shapes
            assert type(x1) == type(x2), \
                f"Inconsistent return types {type(x1)} and {type(x2)}"  # pylint: disable=unidiomatic-typecheck
            if isinstance(x1, torch.Tensor):
                # Make sequences
                x1 = [x1]
                x2 = [x2]
            elif isinstance(x1, (list, tuple)):
                # already a sequence; just normalize to lists
                x1 = list(x1)
                x2 = list(x2)
            else:
                raise TypeError(
                    f"equivariance_error cannot handle output type {type(x1)}")
            assert len(x1) == len(x2)
            assert len(x1) == len(irreps_out)

            # apply the group action to x2
            x2 = _transform(x2, irreps_out, rot_mat, translation)

            error = max((a - b).abs().max() for a, b in zip(x1, x2))

            if error > biggest_errs.get(this_test, neg_inf):
                biggest_errs[this_test] = error

    return biggest_errs
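
A minimal usage sketch, assuming this function is available as e3nn.util.test.equivariance_error and applying it to an o3.Linear layer (neither the import path nor the layer appears in the snippet above):

from e3nn import o3
from e3nn.util.test import equivariance_error

irreps = o3.Irreps("2x0e + 1x1o")
linear = o3.Linear(irreps, irreps)

errs = equivariance_error(
    linear,
    args_in=[irreps.randn(16, -1)],
    irreps_in=[irreps],
    irreps_out=[irreps],
    ntrials=2,
)
# keys are (parity_k, did_translate); for an equivariant layer the values stay near machine precision
print(errs)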