コード例 #1
0
def test_unique_atom_pairs(device):
    """Check the `unique_atom_pairs` helper against a hand-written pair list.

    A Geometry is built from `atomic_numbers_data`/`positions_data` and the
    helper's output is compared element-wise with every ordered pairing of
    the species 1, 6 and 8.
    """
    geometry = Geometry(atomic_numbers_data(device, True),
                        positions_data(device, True))

    expected_pairs = torch.tensor(
        [[1, 1], [6, 1], [8, 1],
         [1, 6], [6, 6], [8, 6],
         [1, 8], [6, 8], [8, 8]],
        device=device)

    elementwise_match = unique_atom_pairs(geometry) == expected_pairs
    assert elementwise_match.all(), \
        "unique_atom_pairs returned an unexpected result"
コード例 #2
0
def geometry_basic_helper(device, positions, atomic_numbers):
    """Run the shared battery of basic `Geometry` checks.

    Arguments:
        device: device the inputs live on; its ``type`` selects the target
            device for the `.to()` check at the end.
        positions: positions of one system, or a list with one tensor per
            system. A list switches the helper into batch mode.
        atomic_numbers: atomic numbers of one system, or a list with one
            tensor per system.
    """
    is_batch = isinstance(atomic_numbers, list)
    angstrom_factor = length_units['angstrom']

    # Build the reference values; in batch mode the per-system inputs are
    # combined with `pack` so they can be compared to the Geometry attributes.
    if not is_batch:
        numbers_expected, positions_expected = atomic_numbers, positions
        positions_in_angstrom = positions / angstrom_factor
    else:
        numbers_expected = pack(atomic_numbers)
        positions_expected = pack(positions)
        positions_in_angstrom = [p / angstrom_factor for p in positions]

    # Check 1: the constructor stores the supplied data.
    geom_1 = Geometry(atomic_numbers, positions)
    assert (torch.allclose(geom_1.atomic_numbers, numbers_expected)
            and torch.allclose(geom_1.positions, positions_expected)), \
        'Geometry was not instantiated correctly'

    # Check 2: positions given in angstrom must end up identical to those of
    # the geometry built in check 1.
    geom_2 = Geometry(atomic_numbers, positions_in_angstrom, units='angstrom')
    assert torch.allclose(geom_1.positions, geom_2.positions), \
        'Geometry failed to correctly convert length units'

    # Check 3: __repr__ must not raise; an exception here fails the test.
    repr(geom_1)

    # Batched geometries holding many systems should get a truncated repr.
    if is_batch:
        ten_copies = Geometry([atomic_numbers[0]] * 10, [positions[0]] * 10)
        assert '...' in repr(ten_copies), \
            'String representation was not correctly truncated'

    # Check 4: `.chemical_symbols` agrees with the `chemical_symbols` table.
    # An entry is a single symbol string, or a list of symbols for one system.
    def _symbols_match(symbol, numbers):
        if isinstance(symbol, str):
            return chemical_symbols[int(numbers)] == symbol
        return [chemical_symbols[int(z)] for z in numbers] == symbol

    assert all([_symbols_match(s, z) for s, z
                in zip(geom_1.chemical_symbols, atomic_numbers)]), \
        'The ".chemical_symbols" property is incorrect'

    # Check 5: `.to()` relocates the tensors; only possible with CUDA present.
    if not torch.cuda.device_count():
        return

    swap_target = {'cuda': torch.device('cpu'),
                   'cpu': torch.device('cuda:0')}
    target = swap_target[device.type]
    # NOTE(review): the return value of `.to()` is discarded, so this check
    # assumes `.to()` moves `geom_1` in place — confirm against Geometry.to.
    geom_1.to(target)
    assert (geom_1.atomic_numbers.device == target
            and geom_1.positions.device == target), \
        '".to" method failed to set the correct device'
コード例 #3
0
def geometry_distance_vectors_helper(atomic_numbers, positions):
    """Validate `Geometry.distance_vectors` against a direct computation.

    Arguments:
        atomic_numbers: atomic numbers of one system, or a list for a batch.
        positions: a single positions tensor, or a list of tensors that are
            differenced one by one and then combined with `pack`.
    """
    geom = Geometry(atomic_numbers, positions)

    def pairwise_difference(xyz):
        # Broadcast subtraction; for an (n_atoms, 3) input (assumed — TODO
        # confirm) entry [i, j] is xyz[i] - xyz[j].
        return xyz.unsqueeze(1) - xyz

    # Check 1: values agree with the reference within tolerance.
    if not isinstance(positions, torch.Tensor):
        expected = pack([pairwise_difference(p) for p in positions])
    else:
        expected = pairwise_difference(positions)
    d_vec = geom.distance_vectors
    assert torch.allclose(d_vec, expected), \
        'Distance vectors are outside of tolerance thresholds'

    # Check 2: the result shares the device of the positions tensor.
    assert d_vec.device == geom.positions.device, \
        'Distance vectors were not returned on the correct device'
コード例 #4
0
def geometry_hdf5_helper(path, atomic_numbers, positions):
    """Function to reduce code duplication when testing the HDF5 functionality.

    A `Geometry` is written to a fresh HDF5 file at ``path`` and the stored
    datasets are compared with the inputs. The geometry is then read back
    via `Geometry.from_hdf5`. In batch mode the read-back is repeated using
    one HDF5 group per system. The file is removed on exit, even when one of
    the checks fails.

    Arguments:
        path: location of the temporary HDF5 database. Any pre-existing file
            at this location is deleted first.
        atomic_numbers: atomic numbers of a single system, or a list holding
            one tensor per system (batch mode).
        positions: positions of a single system, or a list holding one tensor
            per system (batch mode).
    """
    # Ensure any test hdf5 database is erased before running
    if os.path.exists(path):
        os.remove(path)

    # Pack the reference data, if multiple systems provided
    batch = isinstance(atomic_numbers, list)
    atomic_numbers_ref = pack(atomic_numbers) if batch else atomic_numbers
    positions_ref = pack(positions) if batch else positions

    # Construct a geometry instance
    geom_1 = Geometry(atomic_numbers, positions)

    # Infer target device
    device = geom_1.positions.device

    try:
        with h5py.File(path, 'w') as db:
            # Check 1: Write to the database and check that the written data
            # matches the reference data.
            geom_1.to_hdf5(db)
            check_1 = (np.allclose(db['atomic_numbers'][()],
                                   atomic_numbers_ref.sft())
                       and np.allclose(db['positions'][()],
                                       positions_ref.sft()))
            assert check_1, 'Geometry not saved the database correctly'

            # Check 2: Ensure geometries are correctly constructed from hdf5
            # data.
            geom_2 = Geometry.from_hdf5(db, device=device)
            check_2 = (torch.allclose(geom_2.positions, geom_1.positions)
                       and torch.allclose(geom_2.atomic_numbers,
                                          geom_1.atomic_numbers))
            assert check_2, 'Geometry could not be loaded from hdf5 data'

            # Check 3: Make sure that the tensors were placed on the correct
            # device.
            check_3 = (geom_2.positions.device == device
                       and geom_2.atomic_numbers.device == device)
            assert check_3, 'Tensors not placed on the correct device'

        # If this is a batch test then repeat test 2 but pass in a list of
        # HDF5 groups rather than one batch HDF5 group.
        if batch:
            os.remove(path)
            with h5py.File(path, 'w') as db:
                # Collect the group names as they are written so the read-back
                # covers exactly the supplied systems, whatever their number.
                group_names = []
                for n, (an, pos) in enumerate(zip(atomic_numbers, positions),
                                              start=1):
                    name = f'geom_{n}'
                    Geometry(an, pos).to_hdf5(db.create_group(name))
                    group_names.append(name)

                geom_3 = Geometry.from_hdf5([db[name] for name in group_names])
                check_4 = (torch.allclose(geom_3.positions.to(device),
                                          geom_1.positions)
                           and torch.allclose(geom_3.atomic_numbers.to(device),
                                              geom_1.atomic_numbers))
                assert check_4, \
                    'Instance could not be loaded from hdf5 data (batch)'
    finally:
        # Remove the test database, even if one of the checks above failed.
        if os.path.exists(path):
            os.remove(path)
コード例 #5
0
def test_geometry_from_ase_atoms_single(device):
    """Build a single-system Geometry from an ase.Atoms methane molecule.

    Verifies that the positions (including the unit conversion) are parsed
    correctly and that the resulting tensors live on ``device``.
    """
    methane = molecule('CH4')
    geom_1 = Geometry.from_ase_atoms(methane, device=device)

    # Check 1: Geometry positions equal the ase positions scaled by the
    # `length_units['angstrom']` conversion factor.
    parsed_positions = geom_1.positions.sft()
    expected_positions = methane.positions * length_units['angstrom']
    assert np.allclose(parsed_positions, expected_positions), \
        'from_ase_atoms did not correctly parse the positions'

    # Check 2: every tensor was created on the requested device.
    placed_correctly = all(t.device == device for t in
                           (geom_1.positions, geom_1.atomic_numbers))
    assert placed_correctly, \
        'from_ase_atoms did not place tensors on the correct device'
コード例 #6
0
def test_geometry_from_ase_atoms_batch(device):
    """Check batch instances can be instantiated from ase.Atoms objects.

    A list of ase.Atoms (CH4 and H2O) is converted into one Geometry. Its
    positions must match the packed ase positions scaled by the angstrom
    conversion factor, and its tensors must be placed on ``device``.
    """

    # Create a list of ase.Atoms objects and the matching reference positions
    atoms = [molecule('CH4'), molecule('H2O')]
    ref_pos = pack([torch.tensor(i.positions) for i in atoms]).sft()
    ref_pos = ref_pos * length_units['angstrom']

    # Check 1: Ensure that the from_ase_atoms method correctly constructs
    # a geometry instance. This includes the unit conversion operation.
    geom_1 = Geometry.from_ase_atoms(atoms, device=device)
    # No trailing comma: `x = expr,` binds a one-element tuple, which is
    # always truthy and would make the assertion below impossible to fail.
    check_1 = np.allclose(geom_1.positions.sft(), ref_pos)

    assert check_1, 'from_ase_atoms did not correctly parse the positions'

    # Check 2: Check the tensors were placed on the correct device
    check_2 = (geom_1.positions.device == device
               and geom_1.atomic_numbers.device == device)

    assert check_2, 'from_ase_atoms did not place tensors on the correct device'
コード例 #7
0
def test_geometry_distance_batch(device):
    """Run the shared distance checks on a batch-system Geometry."""
    numbers = atomic_numbers_data(device, True)
    coords = positions_data(device, True)
    geometry_distance_helper(Geometry(numbers, coords))