Exemplo n.º 1
0
def test_normalize_rotation():
    """Check NormalizeRotation's repr and its effect on a diagonal point set.

    Five 2-D points on the line y = x, each carrying the normal (-1, 1),
    are transformed twice: once with default arguments and once with
    ``max_points=3``.  Both runs must leave every point with a zero second
    coordinate (scaled by sqrt(2) along the first) and yield the normal
    (0, 1) for every point.
    """
    assert repr(NormalizeRotation()) == 'NormalizeRotation()'

    # Signed offsets along the diagonal; every fixture below is derived
    # from them so inputs and expectations stay in lockstep.
    offsets = (-2, -1, 0, 1, 2)
    pos = torch.Tensor([[k, k] for k in offsets])
    norm = torch.Tensor([[-1, 1] for _ in offsets])

    expected_pos = torch.Tensor([[k * sqrt(2), 0] for k in offsets])
    expected_norm = [[0, 1] for _ in offsets]

    # Same expectations with and without the `max_points` argument.
    for transform_kwargs in ({}, {'max_points': 3}):
        sample = Data(pos=pos)
        sample.norm = norm
        sample = NormalizeRotation(**transform_kwargs)(sample)

        # Only `pos` and `norm` should be stored on the object.
        assert len(sample) == 2
        assert torch.allclose(sample.pos, expected_pos)
        assert sample.norm.tolist() == expected_norm
Exemplo n.º 2
0
    def _obs(self) -> Tuple[Batch, List[List[int]]]:
        """
        Build the graph observation for ``self.mol``.

        Hydrogens are removed, per-atom and per-bond features are collected
        through ``molecule_features``, and the resulting graph is passed
        through ``Center`` and ``NormalizeRotation`` before being wrapped in
        a single-element batch.

        returns
        -------
        Tuple[Batch, List[List[int]]]
            The Batch object contains the Pytorch Geometric graph
            representing the molecule. The second element is
            ``self.nonring``, returned as-is; it is described as the list of
            all torsions of the molecule, each torsion being a list of the
            four atom indices that make it up.
        """
        stripped = Chem.rdmolops.RemoveHs(self.mol)
        conformer = stripped.GetConformer()
        atom_seq = stripped.GetAtoms()
        bond_seq = stripped.GetBonds()

        # One feature row per atom: type features followed by coordinates.
        node_rows = []
        for atom in atom_seq:
            kind = molecule_features.atom_type_CO(atom)
            xyz = molecule_features.atom_coords(atom, conformer)
            node_rows.append(kind + xyz)

        bond_pairs = molecule_features.get_bond_pairs(stripped)

        # Bond features are repeated twice — presumably because
        # get_bond_pairs lists every bond in both directions; confirm
        # against that helper.
        per_bond = [molecule_features.bond_type(bond) for bond in bond_seq]
        edge_rows = per_bond + per_bond

        x = torch.tensor(node_rows, dtype=torch.float)
        edge_index = torch.tensor(bond_pairs, dtype=torch.long)
        edge_attr = torch.tensor(edge_rows, dtype=torch.float)
        pos = torch.Tensor(conformer.GetPositions())

        graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)
        graph = Center()(graph)
        graph = NormalizeRotation()(graph)

        # Overwrite the last three feature columns with the transformed
        # positions so node features agree with `pos`.
        graph.x[:, -3:] = graph.pos

        return Batch.from_data_list([graph]), self.nonring