def test_corruption_fake1():
    """A single fake atom replaces one hydrogen; seeded so the choice is fixed.

    The target is the original symbol at the corrupted position, and
    ``target_mask`` flags exactly that position.
    """
    transform = CorruptionTransform(num_fake=1)
    sample = MoleculeSample(np.array(['H'] * 5), np.eye(5), {'a': 1}, '')
    np.random.seed(0)
    corrupted = transform(sample)

    # Expected attribute values of the corrupted sample under seed 0.
    expected = {
        'atoms': np.array(['H', 'H', 'H', 'H', 'O']),
        'atoms_num': np.array([1, 1, 1, 1, 3]),
        'targets': np.array(['H']),
        'target_mask': np.array([0, 0, 0, 0, 1]),
    }
    for attr_name, want in expected.items():
        assert np.array_equal(getattr(corrupted, attr_name), want), attr_name
def test_corruption_epsilon():
    """With ``epsilon=1`` the seeded run masks two atoms although ``num_masks=1``.

    Under seed 3 positions 1 and 2 become the mask token ``'M'``. Their
    original symbols (``'H'``, ``'N'``) become the targets.
    """
    transform = CorruptionTransform(num_masks=1, epsilon=1)
    sample = MoleculeSample(np.array(['H', 'H', 'N', 'H', 'H']), np.eye(5), {'a': 1}, '')
    np.random.seed(3)
    corrupted = transform(sample)

    expected = {
        'atoms': np.array(['H', 'M', 'M', 'H', 'H']),
        'atoms_num': np.array([1, 6, 6, 1, 1]),
        'targets': np.array(['H', 'N']),
        'target_mask': np.array([0, 1, 1, 0, 0]),
    }
    for attr_name, want in expected.items():
        assert np.array_equal(getattr(corrupted, attr_name), want), attr_name
def test_corruption_mask2_fake2():
    """Two masks plus two fakes corrupt four of five atoms (seed 0).

    Targets list the original symbols at every corrupted position, in
    position order (0, 1, 2, 4).
    """
    transform = CorruptionTransform(num_fake=2, num_masks=2)
    sample = MoleculeSample(np.array(['H', 'N', 'N', 'H', 'H']), np.eye(5), {'a': 1}, '')
    np.random.seed(0)
    corrupted = transform(sample)

    expected = {
        'atoms': np.array(['M', 'C', 'M', 'H', 'N']),
        'atoms_num': np.array([6, 2, 6, 1, 4]),
        'targets': np.array(['H', 'N', 'N', 'H']),
        'target_mask': np.array([1, 1, 1, 0, 1]),
    }
    for attr_name, want in expected.items():
        assert np.array_equal(getattr(corrupted, attr_name), want), attr_name
def test_corruption_override_mask():
    """Applying the transform must not mutate the input ``MoleculeSample``.

    No seed is set, because only the untouched original is inspected.
    """
    symbols = ['H', 'C', 'C', 'H', 'O', 'N', 'N', 'H']
    transform = CorruptionTransform(num_masks=1)
    original = MoleculeSample(np.array(symbols), np.eye(len(symbols)), {'a': 1}, '')

    # The returned sample is irrelevant here; only the side effect (or lack of one) matters.
    transform(original)

    assert np.array_equal(original.atoms, np.array(symbols))
    assert np.array_equal(original.atoms_num, np.array([1, 2, 2, 1, 3, 4, 4, 1]))
    assert np.array_equal(original.adj, np.eye(len(symbols)))
    assert original.properties == {'a': 1}
def __init__(self, num_masks=1, num_fake=0, epsilon_greedy=0.0, num_classes=2,
             num_samples=1000, max_length=30, ambiguity=False, num_bondtypes=1):
    """Create a dataset of synthetic molecular graphs.

    Molecules are produced by a ``MoleculeGenerator``; nothing is loaded
    from disk. Each molecule is passed through ``chem.AddHs`` (assumed to
    be ``rdkit.Chem``, which adds explicit hydrogens). It is then stored
    as a ``MoleculeSample`` holding its atom symbols, adjacency matrix,
    an empty properties dict and its SMILES string. A
    ``CorruptionTransform`` built from the masking/faking arguments is
    kept on ``self.corruption``.

    Args:
        num_masks (int): Number of atoms to mask in each molecule.
        num_fake (int): Number of atoms to fake in each molecule.
        epsilon_greedy (float): Epsilon parameter in the epsilon-greedy
            scheme for selecting the number of corrupted atoms.
        num_classes (int): Stored on the instance and forwarded to
            ``MoleculeGenerator``.
        num_samples (int): Number of molecules to generate.
        max_length (int): Forwarded to ``MoleculeGenerator``. Presumably
            bounds the size of generated molecules -- TODO confirm.
        ambiguity (bool): Forwarded to ``MoleculeGenerator``.
        num_bondtypes (int): Forwarded to ``MoleculeGenerator``.
    """
    self.num_masks = num_masks
    self.num_fake = num_fake
    self.epsilon_greedy = epsilon_greedy
    self.num_classes = num_classes
    self.num_samples = num_samples
    self.molecule_generator = MoleculeGenerator(
        num_classes=num_classes,
        max_length=max_length,
        ambiguity=ambiguity,
        num_bondtypes=num_bondtypes,
    )
    # Note the keyword rename: the transform calls this parameter ``epsilon``.
    self.corruption = CorruptionTransform(
        num_masks=num_masks, num_fake=num_fake, epsilon=epsilon_greedy)

    self.data = []
    for _ in range(self.num_samples):
        molecule = chem.AddHs(self.molecule_generator.generate_molecule())
        adj = chem.rdmolops.GetAdjacencyMatrix(molecule)
        atoms = np.asarray(
            [periodic_table[atom.GetAtomicNum()] for atom in molecule.GetAtoms()])
        smiles = chem.MolToSmiles(molecule)
        self.data.append(MoleculeSample(atoms, adj, {}, smiles))