コード例 #1
0
ファイル: model_utils.py プロジェクト: zizai/graph-nvp
def get_latent_vec(model, mol_smiles, data_name='qm9'):
    out_size = 9
    transform_fn = transform_qm9.transform_fn

    if data_name == 'zinc250k':
        out_size = 38
        transform_fn = transform_fn_zinc250k

    preprocessor = GGNNPreprocessor(out_size=out_size, kekulize=True)
    atoms, adj = preprocessor.get_input_features(
        Chem.MolFromSmiles(mol_smiles))
    atoms, adj, _ = transform_fn((atoms, adj, None))
    adj = np.expand_dims(adj, axis=0)
    atoms = np.expand_dims(atoms, axis=0)
    with chainer.no_backprop_mode():
        z = model(adj, atoms)
    z = np.hstack([z[0][0].data, z[0][1].data]).squeeze(0)
    return z
コード例 #2
0
def set_up_preprocessor(method, max_atoms):
    preprocessor = None

    if method == 'nfp':
        preprocessor = NFPPreprocessor(max_atoms=max_atoms)
    elif method == 'ggnn':
        preprocessor = GGNNPreprocessor(max_atoms=max_atoms)
    elif method == 'schnet':
        preprocessor = SchNetPreprocessor(max_atoms=max_atoms)
    elif method == 'weavenet':
        preprocessor = WeaveNetPreprocessor(max_atoms=max_atoms)
    elif method == 'rsgcn':
        preprocessor = RSGCNPreprocessor(max_atoms=max_atoms)
    else:
        raise ValueError('[ERROR] Invalid method: {}'.format(method))
    return preprocessor
コード例 #3
0
def test_ggnn_preprocessor_kekulize():
    preprocessor = GGNNPreprocessor(kekulize=True)
    dataset = SmilesParser(preprocessor).parse(
        ['C#N', 'Cc1cnc(C=O)n1C', 'c1ccccc1'])["dataset"]
    atoms1, adjs1 = dataset[1]
    assert numpy.allclose(
        atoms1, numpy.array([6, 6, 6, 7, 6, 6, 8, 7, 6], numpy.int32))
    # NOT include aromatic bond (ch=3)
    expect_adjs = numpy.array([[[0., 1., 0., 0., 0., 0., 0., 0., 0.],
                                [1., 0., 0., 0., 0., 0., 0., 1., 0.],
                                [0., 0., 0., 1., 0., 0., 0., 0., 0.],
                                [0., 0., 1., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 1., 0., 1., 0.],
                                [0., 0., 0., 0., 1., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 1., 0., 0., 1., 0., 0., 0., 1.],
                                [0., 0., 0., 0., 0., 0., 0., 1., 0.]],
                               [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 1., 0., 0., 0., 0., 0., 0.],
                                [0., 1., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 1., 0., 0., 0., 0.],
                                [0., 0., 0., 1., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 1., 0., 0.],
                                [0., 0., 0., 0., 0., 1., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.]],
                               [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.]],
                               [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
                              dtype=numpy.float32)
    assert numpy.allclose(adjs1, expect_adjs)
コード例 #4
0
def test_ggnn_preprocessor():
    preprocessor = GGNNPreprocessor()

    def postprocess_label(label_list):
        # Set -1 to the place where the label is not found,
        # this corresponds to not calculate loss with `sigmoid_cross_entropy`
        return [-1 if label is None else label for label in label_list]

    dataset = SDFFileParser(preprocessor,
                            postprocess_label=postprocess_label).parse(
                                get_tox21_filepath('train'))["dataset"]

    index = numpy.random.choice(len(dataset), None)
    atoms, adjs = dataset[index]

    assert atoms.ndim == 1  # (atom, )
    assert atoms.dtype == numpy.int32
    # (edge_type, atom from, atom to)
    assert adjs.ndim == 3
    assert adjs.dtype == numpy.float32
コード例 #5
0
ファイル: download_data.py プロジェクト: zizai/graph-nvp
data_name = args.data_name
data_type = args.data_type
print('args', vars(args))

if data_name == 'qm9':
    max_atoms = 9
elif data_name == 'zinc250k':
    max_atoms = 38
else:
    raise ValueError("[ERROR] Unexpected value data_name={}".format(data_name))

if data_type == 'gcn':
    preprocessor = RSGCNPreprocessor(out_size=max_atoms)
elif data_type == 'relgcn':
    # preprocessor = GGNNPreprocessor(out_size=max_atoms, kekulize=True, return_is_real_node=False)
    preprocessor = GGNNPreprocessor(out_size=max_atoms, kekulize=True)
else:
    raise ValueError("[ERROR] Unexpected value data_type={}".format(data_type))

data_dir = "."
os.makedirs(data_dir, exist_ok=True)

if data_name == 'qm9':
    dataset = datasets.get_qm9(preprocessor)
elif data_name == 'zinc250k':
    dataset = datasets.get_zinc250k(preprocessor)
else:
    raise ValueError("[ERROR] Unexpected value data_name={}".format(data_name))

NumpyTupleDataset.save(
    os.path.join(data_dir,
コード例 #6
0
def test_nfp_preprocessor_assert_raises():
    with pytest.raises(ValueError):
        pp = GGNNPreprocessor(max_atoms=3, out_size=2)  # NOQA
コード例 #7
0
def test_ggnn_preprocessor():
    preprocessor = GGNNPreprocessor()
    dataset = SmilesParser(preprocessor).parse(
        ['C#N', 'Cc1cnc(C=O)n1C', 'c1ccccc1'])["dataset"]

    index = numpy.random.choice(len(dataset), None)
    atoms, adjs = dataset[index]
    assert atoms.ndim == 1  # (atom, )
    assert atoms.dtype == numpy.int32
    # (edge_type, atom from, atom to)
    assert adjs.ndim == 3
    assert adjs.dtype == numpy.float32

    atoms0, adjs0 = dataset[0]
    assert numpy.allclose(atoms0, numpy.array([6, 7], numpy.int32))
    expect_adjs = numpy.array([[[0., 0.], [0., 0.]], [[0., 0.], [0., 0.]],
                               [[0., 1.], [1., 0.]], [[0., 0.], [0., 0.]]],
                              dtype=numpy.float32)
    assert numpy.allclose(adjs0, expect_adjs)

    atoms1, adjs1 = dataset[1]
    assert numpy.allclose(
        atoms1, numpy.array([6, 6, 6, 7, 6, 6, 8, 7, 6], numpy.int32))
    # include aromatic bond (ch=3)
    expect_adjs = numpy.array([[[0., 1., 0., 0., 0., 0., 0., 0., 0.],
                                [1., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 1., 0., 0., 0.],
                                [0., 0., 0., 0., 1., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 1.],
                                [0., 0., 0., 0., 0., 0., 0., 1., 0.]],
                               [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 1., 0., 0.],
                                [0., 0., 0., 0., 0., 1., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.]],
                               [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.]],
                               [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 1., 0., 0., 0., 0., 1., 0.],
                                [0., 1., 0., 1., 0., 0., 0., 0., 0.],
                                [0., 0., 1., 0., 1., 0., 0., 0., 0.],
                                [0., 0., 0., 1., 0., 0., 0., 1., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.],
                                [0., 1., 0., 0., 1., 0., 0., 0., 0.],
                                [0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
                              dtype=numpy.float32)
    assert numpy.allclose(adjs1, expect_adjs)