Exemplo n.º 1
0
def test_collect_atom_triples_batch(four_atoms, ase_env):
    """Check ``env.collect_atom_triples_batch`` against the per-structure version.

    Two neighbor lists are computed for ``four_atoms`` with cutoffs 1.1 and
    3.0, padded with -1 to a common shape and stacked into a batch. The
    batched triple collection must reproduce the stacked results of
    ``env.collect_atom_triples`` applied to each padded neighbor list.
    """
    # Get the first environment (two atoms)
    ase_env.cutoff = 1.1
    nbh_1, _ = ase_env.get_environment(four_atoms)

    # Get the second environment (all_atoms)
    ase_env.cutoff = 3.0
    nbh_2, _ = ase_env.get_environment(four_atoms)

    # Pad to same size with -1. An integer dtype keeps the padded neighbor
    # lists valid index arrays (previously they silently became float).
    max_atoms = max(nbh_1.shape[0], nbh_2.shape[0])
    max_nbh = max(nbh_1.shape[1], nbh_2.shape[1])
    tmp_1 = np.full((max_atoms, max_nbh), -1, dtype=int)
    tmp_2 = np.full((max_atoms, max_nbh), -1, dtype=int)
    tmp_1[: nbh_1.shape[0], : nbh_1.shape[1]] = nbh_1
    tmp_2[: nbh_2.shape[0], : nbh_2.shape[1]] = nbh_2
    nbh_1 = tmp_1
    nbh_2 = tmp_2

    # Get masks and pair indices. ``np.int`` (an alias of the builtin ``int``)
    # was removed in NumPy 1.24, so the builtin is used directly.
    nbh_mask_1 = (nbh_1 >= 0).astype(int)
    nbh_mask_2 = (nbh_2 >= 0).astype(int)
    nbh_1_j, nbh_1_k, offset_idx_1_j, offset_idx_1_k = env.collect_atom_triples(
        nbh_1)
    nbh_2_j, nbh_2_k, offset_idx_2_j, offset_idx_2_k = env.collect_atom_triples(
        nbh_2)

    # Get pairwise masks: a pair is valid only if both partners are real
    # (non-padded) neighbors.
    mask_1_j = np.take_along_axis(nbh_mask_1, offset_idx_1_j, axis=1)
    mask_1_k = np.take_along_axis(nbh_mask_1, offset_idx_1_k, axis=1)
    mask_1_jk = mask_1_j * mask_1_k
    mask_2_j = np.take_along_axis(nbh_mask_2, offset_idx_2_j, axis=1)
    mask_2_k = np.take_along_axis(nbh_mask_2, offset_idx_2_k, axis=1)
    mask_2_jk = mask_2_j * mask_2_k

    # Generate batches and convert to torch
    batch_nbh = torch.LongTensor(np.array([nbh_1, nbh_2]))
    batch_nbh_mask = torch.LongTensor(np.array([nbh_mask_1, nbh_mask_2]))
    batch_nbh_j = torch.LongTensor(np.array([nbh_1_j, nbh_2_j]))
    batch_nbh_k = torch.LongTensor(np.array([nbh_1_k, nbh_2_k]))
    batch_offset_idx_j = torch.LongTensor(
        np.array([offset_idx_1_j, offset_idx_2_j]))
    batch_offset_idx_k = torch.LongTensor(
        np.array([offset_idx_1_k, offset_idx_2_k]))
    batch_mask_jk = torch.LongTensor(np.array([mask_1_jk, mask_2_jk]))

    # Collect triples via batch method
    (
        nbh_j,
        nbh_k,
        offset_idx_j,
        offset_idx_k,
        pair_mask,
    ) = env.collect_atom_triples_batch(batch_nbh, batch_nbh_mask)

    assert np.allclose(batch_nbh_j, nbh_j)
    assert np.allclose(batch_nbh_k, nbh_k)
    assert np.allclose(batch_offset_idx_j, offset_idx_j)
    assert np.allclose(batch_offset_idx_k, offset_idx_k)
    assert np.allclose(batch_mask_jk, pair_mask)
Exemplo n.º 2
0
def test_collect_atom_triples(four_atoms, ase_env):
    """Compare ``env.collect_atom_triples`` with a hand-built reference.

    For every pair of neighbor slots ``(j, k)`` with ``j > k`` (ordered by
    ``k`` first, then ``j``) the reference picks the corresponding entries of
    the ASE neighbor list. The expected offset indices are the slot numbers
    themselves, repeated once per atom.
    """
    nbh, _offsets = ase_env.get_environment(four_atoms)
    n_atoms, n_neighbors = nbh.shape

    # All slot pairs with j strictly above k, outer loop over k.
    slot_pairs = [
        (j, k) for k in range(n_neighbors) for j in range(k + 1, n_neighbors)
    ]
    idx_j = np.array([j for j, _k in slot_pairs])
    idx_k = np.array([k for _j, k in slot_pairs])

    # Reference neighbor pairs taken directly from the ASE neighbor list.
    expected_nbh_j = nbh[:, idx_j]
    expected_nbh_k = nbh[:, idx_k]

    # Reference offset indices: identical slot numbers for every atom.
    expected_off_j = np.tile(idx_j, (n_atoms, 1))
    expected_off_k = np.tile(idx_k, (n_atoms, 1))

    result = env.collect_atom_triples(nbh)
    expected = (expected_nbh_j, expected_nbh_k, expected_off_j, expected_off_k)

    for reference, computed in zip(expected, result):
        assert np.allclose(reference, computed)
Exemplo n.º 3
0
    def __getitem__(self, idx):
        """Return the properties of sample ``idx`` together with its neighbors.

        Args:
            idx (int): index of the requested sample.

        Returns:
            dict: properties returned by ``self.get_properties(idx)``, extended
            by the neighbor list, the cell offsets, the sample index under
            ``"_idx"`` and, if ``self.collect_triples`` is truthy, the
            neighbor-pair indices and their offset indices.
        """
        at, properties = self.get_properties(idx)

        # get atom environment
        nbh_idx, offsets = self.environment_provider.get_environment(at)

        # ``np.int`` was an alias of the builtin ``int`` and was removed in
        # NumPy 1.24, so the builtin is used for all integer conversions.
        properties[Structure.neighbors] = torch.LongTensor(nbh_idx.astype(int))
        properties[Structure.cell_offset] = torch.FloatTensor(
            offsets.astype(np.float32))
        properties["_idx"] = torch.LongTensor(np.array([idx], dtype=int))

        if self.collect_triples:
            nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k = collect_atom_triples(
                nbh_idx)
            properties[Structure.neighbor_pairs_j] = torch.LongTensor(
                nbh_idx_j.astype(int))
            properties[Structure.neighbor_pairs_k] = torch.LongTensor(
                nbh_idx_k.astype(int))
            properties['offset_idx_j'] = torch.LongTensor(
                offset_idx_j.astype(int))
            properties['offset_idx_k'] = torch.LongTensor(
                offset_idx_k.astype(int))
        return properties
Exemplo n.º 4
0
def _convert_atoms(
    atoms,
    environment_provider=SimpleEnvironmentProvider(),
    collect_triples=False,
    center_positions=False,
    output=None,
):
    """
    Helper function to convert ASE atoms object to SchNetPack input format.

    Args:
        atoms (ase.Atoms): Atoms object of molecule
        environment_provider (callable): Neighbor list provider; its
            ``get_environment(atoms)`` returns neighbor indices and cell offsets.
        collect_triples (bool, optional): If True, additionally store neighbor
            pair indices and their offset indices (for angular features).
        center_positions (bool, optional): If True, subtract
            ``atoms.get_center_of_mass()`` from the positions.
        output (dict): Destination for converted atoms, if not None

    Returns:
        dict of torch.Tensor: Atomic numbers, positions, cell, neighbor list,
            cell offsets and (optionally) neighbor pairs in SchNetPack input
            format.
    """
    inputs = {} if output is None else output

    # Elemental composition, positions and cell.
    # ``np.int`` (alias of ``int``) was removed in NumPy 1.24; use the builtin.
    cell = np.array(atoms.cell.array, dtype=np.float32)
    inputs[Properties.Z] = torch.LongTensor(atoms.numbers.astype(int))
    positions = atoms.positions.astype(np.float32)
    if center_positions:
        positions -= atoms.get_center_of_mass()
    inputs[Properties.R] = torch.FloatTensor(positions)
    inputs[Properties.cell] = torch.FloatTensor(cell)

    # get atom environment
    nbh_idx, offsets = environment_provider.get_environment(atoms)

    # Neighbors and cell offsets (the cell itself is already stored above).
    inputs[Properties.neighbors] = torch.LongTensor(nbh_idx.astype(int))
    inputs[Properties.cell_offset] = torch.FloatTensor(offsets.astype(np.float32))

    # If requested get neighbor lists for triples
    if collect_triples:
        nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k = collect_atom_triples(nbh_idx)
        inputs[Properties.neighbor_pairs_j] = torch.LongTensor(nbh_idx_j.astype(int))
        inputs[Properties.neighbor_pairs_k] = torch.LongTensor(nbh_idx_k.astype(int))

        inputs[Properties.neighbor_offsets_j] = torch.LongTensor(
            offset_idx_j.astype(int)
        )
        inputs[Properties.neighbor_offsets_k] = torch.LongTensor(
            offset_idx_k.astype(int)
        )

    return inputs
Exemplo n.º 5
0
def _convert_atoms(
    atoms,
    environment_provider=SimpleEnvironmentProvider(),
    collect_triples=False,
    centering_function=None,
    output=None,
):
    """
    Helper function to convert ASE atoms object to SchNetPack input format.

    Args:
        atoms (ase.Atoms): Atoms object of molecule
        environment_provider (callable): Neighbor list provider.
        collect_triples (bool, optional): Set to True if angular features are needed.
        centering_function (callable or None): Function for calculating center of
            molecule (center of mass/geometry/...). Center will be subtracted from
            positions.
        output (dict): Destination for converted atoms, if not None

    Returns:
        dict of numpy.ndarray: Atomic numbers, positions, neighbor lists, cell,
            cell offsets and (optionally) neighbor pairs in SchNetPack input
            format. The values are numpy arrays, not torch tensors.

    """
    inputs = {} if output is None else output

    # Elemental composition.
    # ``np.int`` (alias of ``int``) was removed in NumPy 1.24; use the builtin.
    inputs[Properties.Z] = atoms.numbers.astype(int)
    positions = atoms.positions.astype(np.float32)
    if centering_function:
        positions -= centering_function(atoms)
    inputs[Properties.R] = positions

    # get atom environment
    nbh_idx, offsets = environment_provider.get_environment(atoms)

    # Get neighbors
    inputs[Properties.neighbors] = nbh_idx.astype(int)

    # Get cells
    inputs[Properties.cell] = np.array(atoms.cell.array, dtype=np.float32)
    inputs[Properties.cell_offset] = offsets.astype(np.float32)

    # If requested get neighbor lists for triples
    if collect_triples:
        nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k = collect_atom_triples(
            nbh_idx)
        inputs[Properties.neighbor_pairs_j] = nbh_idx_j.astype(int)
        inputs[Properties.neighbor_pairs_k] = nbh_idx_k.astype(int)

        inputs[Properties.neighbor_offsets_j] = offset_idx_j.astype(int)
        inputs[Properties.neighbor_offsets_k] = offset_idx_k.astype(int)

    return inputs
Exemplo n.º 6
0
    def convert_atoms(self, atoms):
        """
        Convert an ASE atoms object into a batched SchNetPack input.

        Args:
            atoms (ase.Atoms): Atoms object of molecule

        Returns:
            dict of torch.Tensor: Properties including neighbor lists and masks reformated into SchNetPack
                input format, each with a leading batch dimension of size 1 and
                moved to ``self.device``.
        """
        inputs = {}
        idx = 0

        # Elemental composition.
        # ``np.int`` (alias of ``int``) was removed in NumPy 1.24; use the builtin.
        inputs[Structure.Z] = torch.LongTensor(atoms.numbers.astype(int))
        inputs[Structure.atom_mask] = torch.ones_like(
            inputs[Structure.Z]).float()

        # Set positions
        positions = atoms.positions.astype(np.float32)
        inputs[Structure.R] = torch.FloatTensor(positions)

        # get atom environment
        nbh_idx, offsets = self.environment_provider.get_environment(
            idx, atoms)

        # Neighbor mask marks entries >= 0; padded (negative) neighbors are
        # replaced by index 0 so they stay valid indices.
        mask = torch.FloatTensor(nbh_idx) >= 0
        inputs[Structure.neighbor_mask] = mask.float()
        inputs[Structure.neighbors] = torch.LongTensor(
            nbh_idx.astype(int)) * mask.long()

        # Get cells. ``np.asarray`` works both for a plain array and for an
        # ``ase.cell.Cell`` object, which has no ``astype`` method.
        inputs[Structure.cell] = torch.FloatTensor(
            np.asarray(atoms.cell, dtype=np.float32))
        inputs[Structure.cell_offset] = torch.FloatTensor(
            offsets.astype(np.float32))

        # Set index
        inputs['_idx'] = torch.LongTensor(np.array([idx], dtype=int))

        # Only collect neighbor pairs when requested; the former
        # ``is not None`` check was also true for ``False``.
        if self.collect_triples:
            nbh_idx_j, nbh_idx_k = collect_atom_triples(nbh_idx)
            inputs[Structure.neighbor_pairs_j] = torch.LongTensor(
                nbh_idx_j.astype(int))
            inputs[Structure.neighbor_pairs_k] = torch.LongTensor(
                nbh_idx_k.astype(int))
            inputs[Structure.neighbor_pairs_mask] = torch.ones_like(
                inputs[Structure.neighbor_pairs_j]).float()

        # Add batch dimension and move to CPU/GPU
        for key, value in inputs.items():
            inputs[key] = value.unsqueeze(0).to(self.device)

        return inputs
Exemplo n.º 7
0
    def _sharc2schnet(self, sharc_output):
        """Convert positions received from SHARC into a batched SchNetPack input.

        Args:
            sharc_output: atomic positions; converted with ``np.array`` and
                written to ``self.molecule.positions``.

        Returns:
            dict of torch.Tensor: SchNetPack inputs, each with a leading batch
            dimension of size 1 and moved to ``self.device``.
        """
        # Update internal structure with new SHARC positions
        self.molecule.positions = np.array(sharc_output)

        schnet_inputs = dict()
        # Elemental composition.
        # ``np.int`` (alias of ``int``) was removed in NumPy 1.24; use the builtin.
        schnet_inputs[Properties.Z] = torch.LongTensor(
            self.molecule.numbers.astype(int))
        schnet_inputs[Properties.atom_mask] = torch.ones_like(
            schnet_inputs[Properties.Z]).float()
        # Set positions
        positions = self.molecule.positions.astype(np.float32)
        schnet_inputs[Properties.R] = torch.FloatTensor(positions)

        # get atom environment
        nbh_idx, offsets = self.environment_provider.get_environment(
            self.molecule)
        # Neighbor mask marks entries >= 0; padded (negative) neighbors are
        # replaced by index 0 so they stay valid indices.
        mask = torch.FloatTensor(nbh_idx) >= 0
        schnet_inputs[Properties.neighbor_mask] = mask.float()
        schnet_inputs[Properties.neighbors] = torch.LongTensor(
            nbh_idx.astype(int)) * mask.long()
        # Get cells. ``np.asarray`` works both for a plain array and for an
        # ``ase.cell.Cell`` object, which has no ``astype`` method.
        schnet_inputs[Properties.cell] = torch.FloatTensor(
            np.asarray(self.molecule.cell, dtype=np.float32))
        schnet_inputs[Properties.cell_offset] = torch.FloatTensor(
            offsets.astype(np.float32))
        # Only collect neighbor pairs when requested; the former
        # ``is not None`` check was also true for ``False``.
        if self.collect_triples:
            nbh_idx_j, nbh_idx_k, _, _ = collect_atom_triples(nbh_idx)
            schnet_inputs[Properties.neighbor_pairs_j] = torch.LongTensor(
                nbh_idx_j.astype(int))
            schnet_inputs[Properties.neighbor_pairs_k] = torch.LongTensor(
                nbh_idx_k.astype(int))
            schnet_inputs[Properties.neighbor_pairs_mask] = torch.ones_like(
                schnet_inputs[Properties.neighbor_pairs_j]).float()
        # Add batch dimension and move to CPU/GPU
        for key, value in schnet_inputs.items():
            schnet_inputs[key] = value.unsqueeze(0).to(self.device)

        return schnet_inputs
Exemplo n.º 8
0
def _convert_atoms(
    atoms,
    environment_provider=SimpleEnvironmentProvider(),
    collect_triples=False,
    centering_function=get_center_of_mass,
    output=None,
    res_list=None,
):
    """
    Helper function to convert ASE atoms object to SchNetPack input format.

    The branch taken depends on the class name of ``environment_provider``:

    * names not containing ``"AP"``: standard neighbor list, optional triples
      via ``collect_atom_triples``;
    * ``APModEnvironmentProvider``: neighbor list plus ``"ZB"`` (for every atom
      the atomic numbers of all other atoms) and optional all-ordered-pair
      triples;
    * ``APModPBCEnvironmentProvider``: as above, but the cell offsets are
      replaced by box shifts computed from the displacement vectors;
    * ``APNetPBCEnvironmentProvider``: separate intra-/inter-monomer neighbor
      lists with box shifts, inter/intra triples and a triple mask;
    * any other provider whose name contains ``"AP"``: as
      ``APNetPBCEnvironmentProvider`` but using the provider's offsets as-is.

    Args:
        atoms (ase.Atoms): Atoms object of molecule
        environment_provider (callable): Neighbor list provider.
        collect_triples (bool, optional): Set to True if angular features are
            needed. The APNet-style branches always store triples.
        centering_function (callable or None): Function for calculating center of
            molecule (center of mass/geometry/...). Center will be subtracted from
            positions.
        output (dict): Destination for converted atoms, if not None
        res_list (sequence or None): Only used by the APNet-style branches: two
            sequences whose lengths are the atom counts of monomer A and B,
            used to fill ``"ZA"`` and ``"ZB"``. If None, ``output`` must
            already contain ``"ZA"`` and ``"ZB"``.

    Returns:
        dict of torch.Tensor: Properties including neighbor lists and masks
            reformated into SchNetPack input format.
    """
    inputs = {} if output is None else output

    # Elemental composition, positions and cell.
    # ``np.int``/``np.bool``/``np.float`` were removed in NumPy 1.24; the
    # builtins they aliased are used instead.
    cell = np.array(atoms.cell.array, dtype=np.float32)  # get cell array

    inputs[Properties.Z] = torch.LongTensor(atoms.numbers.astype(int))
    positions = atoms.positions.astype(np.float32)
    if centering_function:
        positions -= centering_function(atoms)
    inputs[Properties.R] = torch.FloatTensor(positions)
    inputs[Properties.cell] = torch.FloatTensor(cell)

    provider_name = type(environment_provider).__name__

    if "AP" not in provider_name:
        # get atom environment
        nbh_idx, offsets = environment_provider.get_environment(atoms)
        inputs[Properties.neighbors] = torch.LongTensor(nbh_idx.astype(int))
        inputs[Properties.cell_offset] = torch.FloatTensor(
            offsets.astype(np.float32))

        # If requested get neighbor lists for triples
        if collect_triples:
            _store_triples(inputs, *collect_atom_triples(nbh_idx))

    elif provider_name == "APModEnvironmentProvider":
        # get atom environment
        nbh_idx, offsets = environment_provider.get_environment(atoms)
        inputs[Properties.neighbors] = torch.LongTensor(nbh_idx.astype(int))
        inputs[Properties.cell_offset] = torch.FloatTensor(
            offsets.astype(np.float32))

        inputs["ZB"] = torch.LongTensor(
            _partner_atomic_numbers(atoms.numbers, nbh_idx.shape[0]))

        # If requested get all ordered neighbor pairs
        if collect_triples:
            _store_triples(inputs, *_ap_pair_triples(nbh_idx))

    elif provider_name == "APModPBCEnvironmentProvider":
        # get atom environment
        nbh_idx, offsets = environment_provider.get_environment(atoms)
        inputs[Properties.neighbors] = torch.LongTensor(nbh_idx.astype(int))

        # Replace the provider's offsets by box shifts derived from the
        # displacement vectors.
        box_shift = _orthorhombic_box_shift(
            inputs[Properties.R],
            inputs[Properties.neighbors],
            inputs[Properties.cell],
            torch.FloatTensor(offsets.astype(np.float32)),
        )
        inputs[Properties.cell_offset] = torch.FloatTensor(box_shift)

        inputs["ZB"] = torch.LongTensor(
            _partner_atomic_numbers(atoms.numbers, nbh_idx.shape[0]))

        # If requested get all ordered neighbor pairs
        if collect_triples:
            _store_triples(inputs, *_ap_pair_triples(nbh_idx))

    elif provider_name == "APNetPBCEnvironmentProvider":
        _store_monomer_numbers(inputs, atoms, res_list)

        # get atom environment
        nbh_idx_intra, offset_intra, nbh_idx_inter, offsets_inter = environment_provider.get_environment(
            atoms, inputs)

        # Inter-monomer neighbors and their box shifts
        _store_inter_neighbors(inputs, nbh_idx_inter)
        inter_shift = _orthorhombic_box_shift(
            inputs[Properties.R],
            inputs[Properties.neighbor_inter],
            inputs[Properties.cell],
            torch.FloatTensor(offsets_inter.astype(np.float32)),
        )
        inputs[Properties.neighbor_offset_inter] = torch.FloatTensor(inter_shift)

        # Inter/intra triples
        *pairs, mask_triples = _inter_intra_triples(nbh_idx_intra, nbh_idx_inter)
        _store_triples(inputs, *pairs)

        # Intra-monomer box shifts
        intra_shift = _orthorhombic_box_shift(
            inputs[Properties.R],
            torch.LongTensor(nbh_idx_intra.astype(int)),
            inputs[Properties.cell],
            torch.FloatTensor(offset_intra.astype(np.float32)),
        )
        inputs[Properties.cell_offset_intra] = torch.FloatTensor(intra_shift)
        inputs[Properties.neighbor_pairs_mask] = torch.LongTensor(mask_triples)

        # Intra-monomer neighbors without the central atom itself
        keep, neighborhood_idx = _drop_self_neighbors(nbh_idx_intra)
        inputs[Properties.neighbors] = torch.LongTensor(
            neighborhood_idx.astype(int))
        intra_shift = intra_shift[keep, :].reshape(
            nbh_idx_intra.shape[0], nbh_idx_intra.shape[1] - 1, 3)
        inputs[Properties.cell_offset] = torch.FloatTensor(intra_shift)

    else:
        _store_monomer_numbers(inputs, atoms, res_list)

        # get atom environment
        nbh_idx_intra, offset_intra, nbh_idx_inter, offsets_inter = environment_provider.get_environment(
            atoms, inputs)

        # Inter-monomer neighbors; the provider's offsets are used unchanged
        _store_inter_neighbors(inputs, nbh_idx_inter)
        inputs[Properties.neighbor_offset_inter] = torch.FloatTensor(
            offsets_inter.astype(np.float32))

        # Inter/intra triples
        *pairs, mask_triples = _inter_intra_triples(nbh_idx_intra, nbh_idx_inter)
        _store_triples(inputs, *pairs)
        inputs[Properties.neighbor_pairs_mask] = torch.LongTensor(mask_triples)

        # Intra-monomer neighbors without the central atom itself
        keep, neighborhood_idx = _drop_self_neighbors(nbh_idx_intra)
        inputs[Properties.neighbors] = torch.LongTensor(
            neighborhood_idx.astype(int))

        inputs[Properties.cell_offset_intra] = torch.FloatTensor(
            offset_intra.astype(np.float32))

        offset_intra = offset_intra[keep, :].reshape(
            nbh_idx_intra.shape[0], nbh_idx_intra.shape[1] - 1, 3)
        inputs[Properties.cell_offset] = torch.FloatTensor(
            offset_intra.astype(np.float32))
    return inputs


def _store_triples(inputs, nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k):
    """Store neighbor-pair indices and their offset indices as LongTensors."""
    inputs[Properties.neighbor_pairs_j] = torch.LongTensor(nbh_idx_j.astype(int))
    inputs[Properties.neighbor_pairs_k] = torch.LongTensor(nbh_idx_k.astype(int))
    inputs[Properties.neighbor_offsets_j] = torch.LongTensor(
        offset_idx_j.astype(int))
    inputs[Properties.neighbor_offsets_k] = torch.LongTensor(
        offset_idx_k.astype(int))


def _ap_pair_triples(nbh_idx):
    """Build all ordered neighbor pairs (j, k) with different neighbor indices.

    Args:
        nbh_idx (numpy.ndarray): neighbor indices, shape (n_atoms, n_neighbors).

    Returns:
        tuple of numpy.ndarray: ``(nbh_idx_j, nbh_idx_k, offset_idx_j,
        offset_idx_k)``, each of shape (n_atoms, n_neighbors * (n_neighbors - 1)).
        The reshape only succeeds if the neighbor indices within each row are
        unique.
    """
    natoms, nneigh = nbh_idx.shape
    pair_shape = (natoms, nneigh * (nneigh - 1))

    # Construct possible permutations
    nbh_idx_j = np.tile(nbh_idx, nneigh)
    nbh_idx_k = np.repeat(nbh_idx, nneigh).reshape((natoms, -1))
    distinct = nbh_idx_j != nbh_idx_k

    # Keep track of periodic images via the neighbor slot of each partner
    offset_idx = np.tile(np.arange(nneigh), (natoms, 1))
    offset_idx_j = np.tile(offset_idx, nneigh)
    offset_idx_k = np.repeat(offset_idx, nneigh).reshape((natoms, -1))

    return tuple(
        arr[distinct].reshape(pair_shape)
        for arr in (nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k)
    )


def _partner_atomic_numbers(numbers, natoms):
    """Return, for every atom, the atomic numbers of all *other* atoms.

    Row ``i`` is ``numbers`` with entry ``i`` removed, giving shape
    (natoms, natoms - 1); ``natoms`` is expected to equal ``len(numbers)``.
    """
    tiled = np.tile(numbers.astype(int)[np.newaxis], (natoms, 1))
    return tiled[~np.eye(natoms, dtype=bool)].reshape(natoms, natoms - 1)


def _orthorhombic_box_shift(positions, neighbors, cell, offsets):
    """Return per-neighbor box shifts that wrap displacement vectors.

    A batch dimension is added for ``atom_distances`` and the resulting
    displacement vectors are rounded in units of the cell diagonal. This is
    only valid for orthorhombic cells (for cubic cells it equals dividing by
    ``cell[0, 0]`` as before).
    """
    _, displacement = atom_distances(
        positions.unsqueeze(0),
        neighbors.unsqueeze(0),
        cell.unsqueeze(0),
        offsets.unsqueeze(0),
        return_vecs=True,
    )
    return -torch.round(displacement.squeeze(0) / torch.diagonal(cell))


def _store_monomer_numbers(inputs, atoms, res_list):
    """Store the atomic numbers of monomer A and B as ``"ZA"``/``"ZB"``.

    If ``res_list`` is None, ``inputs`` must already contain both keys; they
    are converted to long tensors in either case.
    """
    if res_list is not None:
        mon_a = len(res_list[0])
        mon_b = len(res_list[1])
        inputs['ZA'] = torch.LongTensor(atoms.numbers[0:mon_a].astype(int))
        inputs['ZB'] = torch.LongTensor(
            atoms.numbers[mon_a:mon_a + mon_b].astype(int))

    inputs['ZA'], inputs['ZB'] = inputs['ZA'].long(), inputs['ZB'].long()


def _store_inter_neighbors(inputs, nbh_idx_inter):
    """Store inter-monomer neighbors (padding set to 0) and their mask."""
    neighbors_inter = torch.LongTensor(nbh_idx_inter.astype(int))
    mask = (neighbors_inter >= 0).float()
    inputs[Properties.neighbor_inter] = neighbors_inter * mask.long()
    inputs[Properties.neighbor_inter_mask] = mask


def _inter_intra_triples(nbh_idx_intra, nbh_idx_inter):
    """Pair every inter-monomer neighbor j with every intra-monomer neighbor k.

    Returns:
        tuple: ``(nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k, mask)``.
        ``mask`` is 0 where j or k is negative (padding) or where k is the
        central atom (row index) itself, and 1 otherwise.
    """
    natoms, nneigh = nbh_idx_inter.shape
    nbh_idx_k = np.tile(nbh_idx_intra, nneigh)
    nbh_idx_j = np.repeat(nbh_idx_inter, nneigh).reshape((natoms, -1))

    offset_idx = np.tile(np.arange(nneigh), (natoms, 1))
    offset_idx_k = np.tile(offset_idx, nneigh)
    offset_idx_j = np.repeat(offset_idx, nneigh).reshape((natoms, -1))

    pairs_j = nbh_idx_j.astype(int)
    pairs_k = nbh_idx_k.astype(int)

    mask = np.ones(pairs_j.shape, dtype=np.int64)
    mask[pairs_j < 0] = 0
    mask[pairs_k < 0] = 0
    mask[nbh_idx_k == np.arange(nbh_idx_k.shape[0])[:, None]] = 0

    return pairs_j, pairs_k, offset_idx_j.astype(int), offset_idx_k.astype(int), mask


def _drop_self_neighbors(nbh_idx):
    """Remove each atom's own index from its neighbor row.

    Returns:
        tuple: boolean ``keep`` mask of ``nbh_idx``'s shape and the reduced
        neighbor array of shape (n_atoms, n_neighbors - 1). The reshape
        requires every row to contain its own atom index exactly once.
    """
    keep = nbh_idx != np.arange(nbh_idx.shape[0])[:, None]
    reduced = nbh_idx[keep].reshape(nbh_idx.shape[0], nbh_idx.shape[1] - 1)
    return keep, reduced
Exemplo n.º 9
0
def _convert_atoms(
    atoms,
    environment_provider=SimpleEnvironmentProvider(),
    collect_triples=False,
    centering_function=None,
    output=None,
    deleteIntraatomicInteraction=False,
    intraAtomicIdent=None,
):
    """
    Helper function to convert ASE atoms object to SchNetPack input format.

    Args:
        atoms (ase.Atoms): Atoms object of molecule
        environment_provider (callable): Neighbor list provider.
        collect_triples (bool, optional): Set to True if angular features are needed.
        centering_function (callable or None): Function for calculating center of
            molecule (center of mass/geometry/...). Center will be subtracted from
            positions.
        output (dict): Destination for converted atoms, if not None
        deleteIntraatomicInteraction (bool, optional): If True, every neighbor
            entry connecting two atoms with different ``intraAtomicIdent``
            values is replaced by -1 in the stored neighbor list.
        intraAtomicIdent (indexable or None): Per-atom group identifier. If
            None, ``inputs['props'][:, :, 7:8]`` is used, so ``output`` must
            then contain a ``'props'`` entry. NOTE(review): the layout of
            ``'props'`` is not visible here — confirm.

    Returns:
        dict of torch.Tensor: Properties including neighbor lists and masks
            reformated into SchNetPack input format.
    """
    inputs = {} if output is None else output

    # Elemental composition, positions and cell.
    # ``np.int`` (alias of ``int``) was removed in NumPy 1.24; use the builtin.
    cell = np.array(atoms.cell.array, dtype=np.float32)  # get cell array

    inputs[Properties.Z] = torch.LongTensor(atoms.numbers.astype(int))
    positions = atoms.positions.astype(np.float32)
    if centering_function:
        positions -= centering_function(atoms)
    inputs[Properties.R] = torch.FloatTensor(positions)
    inputs[Properties.cell] = torch.FloatTensor(cell)

    # get atom environment
    nbh_idx, offsets = environment_provider.get_environment(atoms)

    if deleteIntraatomicInteraction:
        if intraAtomicIdent is None:
            intraAtomicIdent = inputs['props'][:, :, 7:8]
        # ``astype`` returns a copy, so ``nbh_idx`` itself stays unchanged.
        nl = nbh_idx.astype(int)
        for atom_1, row in enumerate(nl):
            for j, atom_2 in enumerate(row):
                # Padding entries (-1) compare against the last atom, but
                # they remain -1 either way.
                if intraAtomicIdent[atom_1] != intraAtomicIdent[atom_2]:
                    row[j] = -1
        inputs[Properties.neighbors] = torch.LongTensor(nl)
    else:
        inputs[Properties.neighbors] = torch.LongTensor(nbh_idx.astype(int))

    # Cell offsets (the cell itself is already stored above)
    inputs[Properties.cell_offset] = torch.FloatTensor(
        offsets.astype(np.float32))

    # If requested get neighbor lists for triples.
    # NOTE(review): triples are built from the unfiltered ``nbh_idx``, so
    # pairs removed by ``deleteIntraatomicInteraction`` still appear here.
    if collect_triples:
        nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k = collect_atom_triples(
            nbh_idx)
        inputs[Properties.neighbor_pairs_j] = torch.LongTensor(
            nbh_idx_j.astype(int))
        inputs[Properties.neighbor_pairs_k] = torch.LongTensor(
            nbh_idx_k.astype(int))

        inputs[Properties.neighbor_offsets_j] = torch.LongTensor(
            offset_idx_j.astype(int))
        inputs[Properties.neighbor_offsets_k] = torch.LongTensor(
            offset_idx_k.astype(int))

    return inputs