def test_normalize_rotation():
    """Check NormalizeRotation on five collinear 2-D points lying on y = x.

    Every point carries the same normal ``[-1, 1]``. After the transform the
    points are expected on the x-axis and every normal is expected to be
    ``[0, 1]``. The same expectations are checked with the default transform
    and with ``max_points=3``.
    """
    assert repr(NormalizeRotation()) == 'NormalizeRotation()'

    pos = torch.Tensor([[-2, -2], [-1, -1], [0, 0], [1, 1], [2, 2]])
    norm = torch.Tensor([[-1, 1], [-1, 1], [-1, 1], [-1, 1], [-1, 1]])

    expected_pos = torch.Tensor([
        [-2 * sqrt(2), 0],
        [-sqrt(2), 0],
        [0, 0],
        [sqrt(2), 0],
        [2 * sqrt(2), 0],
    ])
    expected_norm = torch.Tensor([[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]])

    for kwargs in ({}, {'max_points': 3}):
        # Build a fresh Data object per configuration so runs are independent.
        data = Data(pos=pos)
        data.norm = norm
        data = NormalizeRotation(**kwargs)(data)

        assert len(data) == 2
        assert torch.allclose(data.pos, expected_pos)
        # The normals are produced by float32 arithmetic; an exact list
        # comparison against integer literals is brittle, so allow a small
        # absolute tolerance around the expected zeros.
        assert torch.allclose(data.norm, expected_norm, atol=1e-6)
def _obs(self) -> Tuple[Batch, List[List[int]]]:
    """Build the graph observation for the current molecule.

    Hydrogens are removed from ``self.mol``. The remaining heavy atoms become
    graph nodes, with atom-type features followed by coordinates. Bonds become
    edges, with bond-type features. The geometry is passed through ``Center``
    and ``NormalizeRotation``, and the result is wrapped in a single-graph
    ``Batch``.

    Returns
    -------
    Tuple[Batch, List[List[int]]]
        The Batch object contains the PyTorch Geometric graph representing the
        molecule. The list of lists of integers is a list of all the torsions
        of the molecule, where each torsion is represented by a list of four
        integers: the indices of the four atoms making up the torsion.
    """
    mol = Chem.rdmolops.RemoveHs(self.mol)
    conf = mol.GetConformer()

    node_features = [
        molecule_features.atom_type_CO(atom) + molecule_features.atom_coords(atom, conf)
        for atom in mol.GetAtoms()
    ]

    # Used directly as edge_index below, so presumably two parallel lists
    # (sources, targets). NOTE(review): confirm get_bond_pairs' shape.
    edge_indices = molecule_features.get_bond_pairs(mol)

    # Attach each directed edge's features by looking up its bond via the
    # endpoints. Repeating the per-bond list twice (`* 2`) is only correct
    # when get_bond_pairs emits all forward edges before all reverse edges.
    # The lookup keeps edge_attr aligned with edge_index for any edge order.
    bond_features = {}
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        feature = molecule_features.bond_type(bond)
        bond_features[(begin, end)] = feature
        bond_features[(end, begin)] = feature
    edge_attributes = [bond_features[(src, dst)] for src, dst in zip(*edge_indices)]

    data = Data(
        x=torch.tensor(node_features, dtype=torch.float),
        edge_index=torch.tensor(edge_indices, dtype=torch.long),
        edge_attr=torch.tensor(edge_attributes, dtype=torch.float),
        # Explicit float dtype so pos always matches x, whose trailing
        # columns it overwrites below.
        pos=torch.tensor(conf.GetPositions(), dtype=torch.float),
    )

    data = Center()(data)
    data = NormalizeRotation()(data)
    # Replace the coordinate columns of the node features with the
    # transformed positions. NOTE(review): assumes atom_coords contributes
    # exactly the last 3 feature columns.
    data.x[:, -3:] = data.pos

    return Batch.from_data_list([data]), self.nonring