def get_latent_vec(model, mol_smiles, data_name='qm9'): out_size = 9 transform_fn = transform_qm9.transform_fn if data_name == 'zinc250k': out_size = 38 transform_fn = transform_fn_zinc250k preprocessor = GGNNPreprocessor(out_size=out_size, kekulize=True) atoms, adj = preprocessor.get_input_features( Chem.MolFromSmiles(mol_smiles)) atoms, adj, _ = transform_fn((atoms, adj, None)) adj = np.expand_dims(adj, axis=0) atoms = np.expand_dims(atoms, axis=0) with chainer.no_backprop_mode(): z = model(adj, atoms) z = np.hstack([z[0][0].data, z[0][1].data]).squeeze(0) return z
def set_up_preprocessor(method, max_atoms): preprocessor = None if method == 'nfp': preprocessor = NFPPreprocessor(max_atoms=max_atoms) elif method == 'ggnn': preprocessor = GGNNPreprocessor(max_atoms=max_atoms) elif method == 'schnet': preprocessor = SchNetPreprocessor(max_atoms=max_atoms) elif method == 'weavenet': preprocessor = WeaveNetPreprocessor(max_atoms=max_atoms) elif method == 'rsgcn': preprocessor = RSGCNPreprocessor(max_atoms=max_atoms) else: raise ValueError('[ERROR] Invalid method: {}'.format(method)) return preprocessor
def test_ggnn_preprocessor_kekulize(): preprocessor = GGNNPreprocessor(kekulize=True) dataset = SmilesParser(preprocessor).parse( ['C#N', 'Cc1cnc(C=O)n1C', 'c1ccccc1'])["dataset"] atoms1, adjs1 = dataset[1] assert numpy.allclose( atoms1, numpy.array([6, 6, 6, 7, 6, 6, 8, 7, 6], numpy.int32)) # NOT include aromatic bond (ch=3) expect_adjs = numpy.array([[[0., 1., 0., 0., 0., 0., 0., 0., 0.], [1., 0., 0., 0., 0., 0., 0., 1., 0.], [0., 0., 0., 1., 0., 0., 0., 0., 0.], [0., 0., 1., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 1., 0., 1., 0.], [0., 0., 0., 0., 1., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 1., 0., 0., 1., 0., 0., 0., 1.], [0., 0., 0., 0., 0., 0., 0., 1., 0.]], [[0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 1., 0., 0., 0., 0., 0., 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 1., 0., 0., 0., 0.], [0., 0., 0., 1., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 1., 0., 0.], [0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.]]], dtype=numpy.float32) assert numpy.allclose(adjs1, expect_adjs)
def test_ggnn_preprocessor(): preprocessor = GGNNPreprocessor() def postprocess_label(label_list): # Set -1 to the place where the label is not found, # this corresponds to not calculate loss with `sigmoid_cross_entropy` return [-1 if label is None else label for label in label_list] dataset = SDFFileParser(preprocessor, postprocess_label=postprocess_label).parse( get_tox21_filepath('train'))["dataset"] index = numpy.random.choice(len(dataset), None) atoms, adjs = dataset[index] assert atoms.ndim == 1 # (atom, ) assert atoms.dtype == numpy.int32 # (edge_type, atom from, atom to) assert adjs.ndim == 3 assert adjs.dtype == numpy.float32
data_name = args.data_name data_type = args.data_type print('args', vars(args)) if data_name == 'qm9': max_atoms = 9 elif data_name == 'zinc250k': max_atoms = 38 else: raise ValueError("[ERROR] Unexpected value data_name={}".format(data_name)) if data_type == 'gcn': preprocessor = RSGCNPreprocessor(out_size=max_atoms) elif data_type == 'relgcn': # preprocessor = GGNNPreprocessor(out_size=max_atoms, kekulize=True, return_is_real_node=False) preprocessor = GGNNPreprocessor(out_size=max_atoms, kekulize=True) else: raise ValueError("[ERROR] Unexpected value data_type={}".format(data_type)) data_dir = "." os.makedirs(data_dir, exist_ok=True) if data_name == 'qm9': dataset = datasets.get_qm9(preprocessor) elif data_name == 'zinc250k': dataset = datasets.get_zinc250k(preprocessor) else: raise ValueError("[ERROR] Unexpected value data_name={}".format(data_name)) NumpyTupleDataset.save( os.path.join(data_dir,
def test_nfp_preprocessor_assert_raises(): with pytest.raises(ValueError): pp = GGNNPreprocessor(max_atoms=3, out_size=2) # NOQA
def test_ggnn_preprocessor(): preprocessor = GGNNPreprocessor() dataset = SmilesParser(preprocessor).parse( ['C#N', 'Cc1cnc(C=O)n1C', 'c1ccccc1'])["dataset"] index = numpy.random.choice(len(dataset), None) atoms, adjs = dataset[index] assert atoms.ndim == 1 # (atom, ) assert atoms.dtype == numpy.int32 # (edge_type, atom from, atom to) assert adjs.ndim == 3 assert adjs.dtype == numpy.float32 atoms0, adjs0 = dataset[0] assert numpy.allclose(atoms0, numpy.array([6, 7], numpy.int32)) expect_adjs = numpy.array([[[0., 0.], [0., 0.]], [[0., 0.], [0., 0.]], [[0., 1.], [1., 0.]], [[0., 0.], [0., 0.]]], dtype=numpy.float32) assert numpy.allclose(adjs0, expect_adjs) atoms1, adjs1 = dataset[1] assert numpy.allclose( atoms1, numpy.array([6, 6, 6, 7, 6, 6, 8, 7, 6], numpy.int32)) # include aromatic bond (ch=3) expect_adjs = numpy.array([[[0., 1., 0., 0., 0., 0., 0., 0., 0.], [1., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 0., 0., 0., 1., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 1.], [0., 0., 0., 0., 0., 0., 0., 1., 0.]], [[0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 1., 0., 0.], [0., 0., 0., 0., 0., 1., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 1., 0., 0., 0., 0., 1., 0.], [0., 1., 0., 1., 0., 0., 0., 0., 0.], [0., 0., 1., 0., 1., 0., 0., 0., 0.], [0., 0., 0., 1., 0., 0., 0., 1., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 1., 0., 0., 1., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0.]]], dtype=numpy.float32) assert numpy.allclose(adjs1, expect_adjs)