Code Example #1
    def __init__(self,
                 smiles_to_graph=smiles_to_bigraph,
                 node_featurizer=None,
                 edge_featurizer=None,
                 load=False,
                 log_every=1000,
                 cache_file_path='./bbbp_dglgraph.bin',
                 n_jobs=1):

        self._url = 'dataset/bbbp.zip'
        data_path = get_download_dir() + '/bbbp.zip'
        dir_path = get_download_dir() + '/bbbp'
        download(_get_dgl_url(self._url), path=data_path, overwrite=False)
        extract_archive(data_path, dir_path)
        df = pd.read_csv(dir_path + '/BBBP.csv')

        super(BBBP, self).__init__(df=df,
                                   smiles_to_graph=smiles_to_graph,
                                   node_featurizer=node_featurizer,
                                   edge_featurizer=edge_featurizer,
                                   smiles_column='smiles',
                                   cache_file_path=cache_file_path,
                                   task_names=['p_np'],
                                   load=load,
                                   log_every=log_every,
                                   init_mask=True,
                                   n_jobs=n_jobs)

        self.load_full = False
        self.names = df['name'].tolist()
        self.names = [self.names[i] for i in self.valid_ids]
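
A minimal usage sketch for this dataset class, assuming the dgllife package (BBBP, smiles_to_bigraph, and CanonicalAtomFeaturizer all live there); with init_mask=True, each item carries a mask alongside the label:

from dgllife.data import BBBP
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer

dataset = BBBP(smiles_to_graph=smiles_to_bigraph,
               node_featurizer=CanonicalAtomFeaturizer())
# init_mask=True datasets yield 4-tuples: SMILES string, DGLGraph,
# label tensor, and a mask flagging which labels are actually present
smiles, g, label, mask = dataset[0]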
Code Example #2
 def __init__(self, name):
     self.name = name
     self.dir = get_download_dir()
     tgz_path = os.path.join(self.dir, '{}.tgz'.format(self.name))
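     # '_downlaod_prefix' (sic) is the misspelled but real module-level constant in the upstream source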
     download(_downlaod_prefix + '{}.tgz'.format(self.name), tgz_path)
     self.dir = os.path.join(self.dir, self.name)
     extract_archive(tgz_path, self.dir)
Code Example #3
    def __init__(self, data, vocab, training=True):
        self.dir = get_download_dir()
        self.zip_file_path = '{}/jtnn.zip'.format(self.dir)
        download(_url, path=self.zip_file_path)
        extract_archive(self.zip_file_path, '{}/jtnn'.format(self.dir))
        print('Loading data...')
        if data in ['train', 'test']:
            data_file = '{}/jtnn/{}.txt'.format(self.dir, data)
        else:
            data_file = data
        with open(data_file) as f:
            self.data = [line.strip("\r\n ").split()[0] for line in f]

        if vocab == 'zinc':
            self.vocab_file = '{}/jtnn/vocab.txt'.format(self.dir)
        elif vocab == 'guacamol':
            self.vocab_file = '{}/jtnn/vocab_guacamol.txt'.format(self.dir)
        else:
            self.vocab_file = vocab

        print('Loading finished.')
        print('\tNum samples:', len(self.data))
        print('\tVocab file:', self.vocab_file)
        self.training = training
        self.vocab = Vocab([x.strip("\r\n ") for x in open(self.vocab_file)])
Code Example #4
    def __init__(self,
                 smiles_to_graph=smiles_to_bigraph,
                 node_featurizer=None,
                 edge_featurizer=None,
                 load=True,
                 log_every=1000,
                 cache_file_path='lipophilicity_dglgraph.bin'):

        self._url = 'dataset/lipophilicity.zip'
        data_path = get_download_dir() + '/lipophilicity.zip'
        dir_path = get_download_dir() + '/lipophilicity'
        download(_get_dgl_url(self._url), path=data_path)
        extract_archive(data_path, dir_path)
        df = pd.read_csv(dir_path + '/Lipophilicity.csv')

        # ChEMBL ids
        self.chembl_ids = df['CMPD_CHEMBLID'].tolist()

        self.load_full = False

        super(Lipophilicity, self).__init__(df=df,
                                            smiles_to_graph=smiles_to_graph,
                                            node_featurizer=node_featurizer,
                                            edge_featurizer=edge_featurizer,
                                            smiles_column='smiles',
                                            cache_file_path=cache_file_path,
                                            task_names=['exp'],
                                            load=load,
                                            log_every=log_every,
                                            init_mask=False)
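
Since this constructor passes init_mask=False, indexing presumably yields plain (smiles, graph, label) triples rather than the four-tuple of the masked datasets; a minimal sketch, again assuming dgllife:

from dgllife.data import Lipophilicity

dataset = Lipophilicity()  # defaults above: reuse cached graphs if available
# no mask here: the 'exp' regression target is defined for every molecule
smiles, g, label = dataset[0]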
Code Example #5
File: muv.py, Project: zwvews/dgl-lifesci
    def __init__(self,
                 smiles_to_graph=smiles_to_bigraph,
                 node_featurizer=None,
                 edge_featurizer=None,
                 load=False,
                 log_every=1000,
                 cache_file_path='./muv_dglgraph.bin',
                 n_jobs=1):

        self._url = 'dataset/muv.zip'
        data_path = get_download_dir() + '/muv.zip'
        dir_path = get_download_dir() + '/muv'
        download(_get_dgl_url(self._url), path=data_path, overwrite=False)
        extract_archive(data_path, dir_path)
        df = pd.read_csv(dir_path + '/muv.csv')

        self.ids = df['mol_id'].tolist()
        self.load_full = False

        df = df.drop(columns=['mol_id'])

        super(MUV, self).__init__(df=df,
                                  smiles_to_graph=smiles_to_graph,
                                  node_featurizer=node_featurizer,
                                  edge_featurizer=edge_featurizer,
                                  smiles_column='smiles',
                                  cache_file_path=cache_file_path,
                                  load=load,
                                  log_every=log_every,
                                  init_mask=True,
                                  n_jobs=n_jobs)

        self.ids = [self.ids[i] for i in self.valid_ids]
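
Note the last line: the base class keeps only the molecules whose SMILES could be turned into graphs and records their positions in self.valid_ids, so per-molecule metadata (here the mol_id column) has to be re-indexed to stay aligned, just as the names are in Code Example #1.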
Code Example #6
    def __init__(self, hidden_size, latent_size, depth, vocab=None, vocab_file=None):
        super(DGLJTNNVAE, self).__init__()
        if vocab is None:
            if vocab_file is None:
                default_dir = get_download_dir()
                vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
                zip_file_path = '{}/jtnn.zip'.format(default_dir)
                download(_get_dgl_url('dataset/jtnn.zip'), path=zip_file_path)
                extract_archive(zip_file_path, '{}/jtnn'.format(default_dir))

            self.vocab = Vocab([x.strip("\r\n ") for x in open(vocab_file)])
        else:
            self.vocab = vocab

        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        self.embedding = nn.Embedding(self.vocab.size(), hidden_size)
        self.mpn = DGLMPN(hidden_size, depth)
        self.jtnn = DGLJTNNEncoder(self.vocab, hidden_size, self.embedding)
        self.decoder = DGLJTNNDecoder(
            self.vocab, hidden_size, latent_size // 2, self.embedding)
        self.jtmpn = DGLJTMPN(hidden_size, depth)

        self.T_mean = nn.Linear(hidden_size, latent_size // 2)
        self.T_var = nn.Linear(hidden_size, latent_size // 2)
        self.G_mean = nn.Linear(hidden_size, latent_size // 2)
        self.G_var = nn.Linear(hidden_size, latent_size // 2)

        self.n_nodes_total = 0
        self.n_passes = 0
        self.n_edges_total = 0
        self.n_tree_nodes_total = 0
Code Example #7
File: vocab.py, Project: cyhFlight/dgl-lifesci
    def __init__(self, file_path=None):
        if file_path is None:
            from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive

            default_dir = get_download_dir()
            vocab_file = '{}/jtvae/vocab.txt'.format(default_dir)
            zip_file_path = '{}/jtvae.zip'.format(default_dir)
            download(_get_dgl_url('dataset/jtvae.zip'), path=zip_file_path, overwrite=False)
            extract_archive(zip_file_path, '{}/jtvae'.format(default_dir))

            with open(vocab_file, 'r') as f:
                self.vocab = [x.strip("\r\n ") for x in f]
        else:
            # Prepare a vocabulary from scratch
            vocab = set()
            with open(file_path, 'r') as f:
                for line in f:
                    smiles = line.split()[0]
                    mol = MolTree(smiles)
                    for i in mol.nodes_dict:
                        vocab.add(mol.nodes_dict[i]['smiles'])
            self.vocab = list(vocab)

        self.vmap = {x: i for i, x in enumerate(self.vocab)}
        self.slots = [get_slots(smiles) for smiles in self.vocab]
Code Example #8
File: uspto.py, Project: jjhu94/dgl-1
    def __init__(self,
                 subset,
                 mol_to_graph=mol_to_bigraph,
                 node_featurizer=default_node_featurizer,
                 edge_featurizer=default_edge_featurizer,
                 atom_pair_featurizer=default_atom_pair_featurizer,
                 load=True):
        assert subset in ['train', 'val', 'test'], \
            'Expect subset to be "train" or "val" or "test", got {}'.format(subset)
        print('Preparing {} subset of USPTO'.format(subset))
        self._subset = subset
        if subset == 'val':
            subset = 'valid'

        self._url = 'dataset/uspto.zip'
        data_path = get_download_dir() + '/uspto.zip'
        extracted_data_path = get_download_dir() + '/uspto'
        download(_get_dgl_url(self._url), path=data_path)
        extract_archive(data_path, extracted_data_path)

        super(USPTO, self).__init__(
            raw_file_path=extracted_data_path + '/{}.txt'.format(subset),
            mol_graph_path=extracted_data_path + '/{}_mol_graphs.bin'.format(subset),
            mol_to_graph=mol_to_graph,
            node_featurizer=node_featurizer,
            edge_featurizer=edge_featurizer,
            atom_pair_featurizer=atom_pair_featurizer,
            load=load)
Code Example #9
    def __init__(self,
                 smiles_to_graph=smiles_to_bigraph,
                 node_featurizer=None,
                 edge_featurizer=None,
                 load=True,
                 log_every=1000,
                 cache_file_path='freesolv_dglgraph.bin'):

        self._url = 'dataset/FreeSolv.zip'
        data_path = get_download_dir() + '/FreeSolv.zip'
        dir_path = get_download_dir() + '/FreeSolv'
        download(_get_dgl_url(self._url), path=data_path)
        extract_archive(data_path, dir_path)
        df = pd.read_csv(dir_path + '/SAMPL.csv')

        # Iupac names
        self.iupac_names = df['iupac'].tolist()
        # Calculated hydration free energy
        self.calc_energy = df['calc'].tolist()

        self.load_full = False

        super(FreeSolv, self).__init__(df=df,
                                       smiles_to_graph=smiles_to_graph,
                                       node_featurizer=node_featurizer,
                                       edge_featurizer=edge_featurizer,
                                       smiles_column='smiles',
                                       cache_file_path=cache_file_path,
                                       task_names=['expt'],
                                       load=load,
                                       log_every=log_every,
                                       init_mask=False)
Code Example #10
File: pdbbind.py, Project: zwvews/dgl-lifesci
    def __init__(self, subset, load_binding_pocket=True, sanitize=False, calc_charges=False,
                 remove_hs=False, use_conformation=True,
                 construct_graph_and_featurize=ACNN_graph_construction_and_featurization,
                 zero_padding=True, num_processes=64):
        self.task_names = ['-logKd/Ki']
        self.n_tasks = len(self.task_names)

        self._url = 'dataset/pdbbind_v2015.tar.gz'
        root_dir_path = get_download_dir()
        data_path = root_dir_path + '/pdbbind_v2015.tar.gz'
        extracted_data_path = root_dir_path + '/pdbbind_v2015'
        download(_get_dgl_url(self._url), path=data_path, overwrite=False)
        extract_archive(data_path, extracted_data_path)

        if subset == 'core':
            index_label_file = extracted_data_path + '/v2015/INDEX_core_data.2013'
        elif subset == 'refined':
            index_label_file = extracted_data_path + '/v2015/INDEX_refined_data.2015'
        else:
            raise ValueError(
                'Expect the subset_choice to be either '
                'core or refined, got {}'.format(subset))

        self._preprocess(extracted_data_path, index_label_file, load_binding_pocket,
                         sanitize, calc_charges, remove_hs, use_conformation,
                         construct_graph_and_featurize, zero_padding, num_processes)
Code Example #11
File: sider.py, Project: zwvews/dgl-lifesci
    def __init__(self,
                 smiles_to_graph=smiles_to_bigraph,
                 node_featurizer=None,
                 edge_featurizer=None,
                 load=False,
                 log_every=1000,
                 cache_file_path='./sider_dglgraph.bin',
                 n_jobs=1):

        self._url = 'dataset/sider.zip'
        data_path = get_download_dir() + '/sider.zip'
        dir_path = get_download_dir() + '/sider'
        download(_get_dgl_url(self._url), path=data_path, overwrite=False)
        extract_archive(data_path, dir_path)
        df = pd.read_csv(dir_path + '/sider.csv')

        super(SIDER, self).__init__(df=df,
                                    smiles_to_graph=smiles_to_graph,
                                    node_featurizer=node_featurizer,
                                    edge_featurizer=edge_featurizer,
                                    smiles_column='smiles',
                                    cache_file_path=cache_file_path,
                                    load=load,
                                    log_every=log_every,
                                    init_mask=True,
                                    n_jobs=n_jobs)
Code Example #12
    def __init__(self, hidden_size, latent_size, depth, vocab_file=None):
        super(DGLJTNNVAE, self).__init__()

        if vocab_file is None:
            default_dir = get_download_dir()
            vocab_file = '{}/jtvae/{}.txt'.format(default_dir, 'vocab')
            zip_file_path = '{}/jtvae.zip'.format(default_dir)
            download(_get_dgl_url('dataset/jtvae.zip'), path=zip_file_path)
            extract_archive(zip_file_path, '{}/jtvae'.format(default_dir))

        with open(vocab_file, 'r') as f:
            self.vocab = Vocab([x.strip("\r\n ") for x in f])

        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        self.embedding = nn.Embedding(self.vocab.size(), hidden_size)
        self.mpn = DGLMPN(hidden_size, depth)
        self.jtnn = DGLJTNNEncoder(self.vocab, hidden_size, self.embedding)
        self.decoder = DGLJTNNDecoder(self.vocab, hidden_size,
                                      latent_size // 2, self.embedding)
        self.jtmpn = DGLJTMPN(hidden_size, depth)

        self.T_mean = nn.Linear(hidden_size, latent_size // 2)
        self.T_var = nn.Linear(hidden_size, latent_size // 2)
        self.G_mean = nn.Linear(hidden_size, latent_size // 2)
        self.G_var = nn.Linear(hidden_size, latent_size // 2)

        self.atom_featurizer_enc = get_atom_featurizer_enc()
        self.bond_featurizer_enc = get_bond_featurizer_enc()
        self.atom_featurizer_dec = get_atom_featurizer_dec()
        self.bond_featurizer_dec = get_bond_featurizer_dec()
Code Example #13
    def __init__(self, data, vocab, training=True):
        dir = get_download_dir()

        _url = _get_dgl_url('dataset/jtnn.zip')
        zip_file_path = '{}/jtnn.zip'.format(dir)
        download(_url, path=zip_file_path)
        extract_archive(zip_file_path, '{}/jtnn'.format(dir))

        print('Loading data...')
        if data in ['train', 'test']:
            # ZINC subset
            data_file = '{}/jtnn/{}.txt'.format(dir, data)
        else:
            # New dataset
            data_file = data
        with open(data_file) as f:
            self.data = [line.strip("\r\n ").split()[0] for line in f]
        self.vocab = vocab

        print('Loading finished')
        print('\t# samples:', len(self.data))
        self.training = training

        self.atom_featurizer_enc = get_atom_featurizer_enc()
        self.bond_featurizer_enc = get_bond_featurizer_enc()
        self.atom_featurizer_dec = get_atom_featurizer_dec()
        self.bond_featurizer_dec = get_bond_featurizer_dec()
Code Example #14
 def _download(self):
     download_dir = get_download_dir()
     zip_file_path = os.path.join(download_dir, "tu_{}.zip".format(self.name))
     download(self._url.format(self.name), path=zip_file_path)
     extract_dir = os.path.join(download_dir, "tu_{}".format(self.name))
     extract_archive(zip_file_path, extract_dir)
     return extract_dir
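
Nearly every example on this page repeats the same three-step shape: locate the shared download directory, fetch an archive, extract it next to the archive. A standalone sketch of that pattern, assuming only the dgl.data.utils helpers used above (the 'dataset/{name}.zip' URL layout is illustrative):

import os
from dgl.data.utils import _get_dgl_url, download, extract_archive, get_download_dir

def fetch_dataset(name):
    # Fetch '<name>.zip' into DGL's shared download directory; with
    # overwrite=False an existing archive is reused instead of re-downloaded.
    download_dir = get_download_dir()
    zip_path = os.path.join(download_dir, '{}.zip'.format(name))
    download(_get_dgl_url('dataset/{}.zip'.format(name)), path=zip_path, overwrite=False)
    extract_dir = os.path.join(download_dir, name)
    extract_archive(zip_path, extract_dir)
    return extract_dir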
Code Example #15
File: test_datasets.py, Project: zwvews/dgl-lifesci
def test_jtvae():
    # Test DGLMolTree
    smiles = 'CC1([C@@H](N2[C@H](S1)[C@@H](C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C'
    tree = DGLMolTree(smiles)
    assert tree.treesize() == 17
    tree.assemble()
    assert tree._recover_node(0, tree.mol) == 'C[CH3:15]'
    tree.recover()

    # Test JTVAEDataset
    smiles = [
        'CCCCCCC1=NN2C(=N)/C(=C\c3cc(C)n(-c4ccc(C)cc4C)c3C)C(=O)N=C2S1',
        'COCC[C@@H](C)C(=O)N(C)Cc1ccc(O)cc1'
    ]
    with open('data.txt', 'w') as f:
        for smi in smiles:
            f.write(smi + '\n')

    default_dir = get_download_dir()
    vocab_file = '{}/jtnn/{}.txt'.format(default_dir, 'vocab')
    zip_file_path = '{}/jtnn.zip'.format(default_dir)
    download(_get_dgl_url('dataset/jtnn.zip'),
             path=zip_file_path,
             overwrite=False)
    extract_archive(zip_file_path, '{}/jtnn'.format(default_dir))

    with open(vocab_file, 'r') as f:
        vocab = Vocab([x.strip("\r\n ") for x in f])
    dataset = JTVAEDataset('data.txt', vocab)
    assert len(dataset) == 2
    assert set(dataset[0].keys()) == {
        'cand_graphs', 'mol_graph', 'mol_tree', 'stereo_cand_graphs',
        'stereo_cand_label', 'tree_mess_src_e', 'tree_mess_tgt_e',
        'tree_mess_tgt_n', 'wid'
    }
    dataset.training = False
    assert set(dataset[0].keys()) == {'mol_graph', 'mol_tree', 'wid'}

    dataset.training = True
    collate_fn = JTVAECollator(training=True)
    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
    for _, batch_data in enumerate(loader):
        assert set(batch_data.keys()) == {
            'cand_batch_idx', 'cand_graph_batch', 'mol_graph_batch',
            'mol_trees', 'stereo_cand_batch_idx', 'stereo_cand_graph_batch',
            'stereo_cand_labels', 'stereo_cand_lengths', 'tree_mess_src_e',
            'tree_mess_tgt_e', 'tree_mess_tgt_n'
        }

    dataset.training = False
    collate_fn = JTVAECollator(training=False)
    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
    for _, batch_data in enumerate(loader):
        assert set(batch_data.keys()) == {'mol_graph_batch', 'mol_trees'}

    remove_file('data.txt')
    remove_file(zip_file_path)
    remove_dir(default_dir + '/jtnn')
Code Example #16
File: dataset.py, Project: sailfish009/GraphNorm
 def _download(self):
     download_dir = get_download_dir()
     zip_file_path = os.path.join(
         download_dir, "{}.zip".format(self.ds_name))
     # TODO move to dgl host _get_dgl_url
     download(_url, path=zip_file_path)
     extract_dir = os.path.join(
         download_dir, "{}".format(self.ds_name))
     extract_archive(zip_file_path, extract_dir)
     return extract_dir
Code Example #17
def _download_babi_data():
    download_dir = get_download_dir()
    zip_file_path = os.path.join(download_dir, 'babi_data.zip')

    data_url = _get_dgl_url('models/ggnn_babi_data.zip')
    download(data_url, path=zip_file_path)

    extract_dir = os.path.join(download_dir, 'babi_data')
    if not os.path.exists(extract_dir):
        extract_archive(zip_file_path, extract_dir)
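
Unlike most snippets here, this one guards the extraction with an existence check, so repeated calls only pay the extraction cost once; the download itself is still repeated unless overwrite=False is passed.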
Code Example #18
File: QM9Edge.py, Project: Jack-XHP/DGL_QM9EDGE
    def download(self):
        if not os.path.exists(os.path.join(self.raw_dir, "gdb9.sdf.csv")):
            file_path = download(self.raw_url, self.raw_dir)
            extract_archive(file_path, self.raw_dir, overwrite=True)
            os.unlink(file_path)

        if not os.path.exists(os.path.join(self.raw_dir,
                                           "uncharacterized.txt")):
            file_path = download(self.raw_url2, self.raw_dir)
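            # download() saves the file under its remote basename ('3195404',
            # presumably a figshare file id), so rename it to something readable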
            os.replace(os.path.join(self.raw_dir, '3195404'),
                       os.path.join(self.raw_dir, 'uncharacterized.txt'))
Code Example #19
File: Alchemy_dataset.py, Project: lyf35/Alchemy
    def __init__(self, mode='dev', transform=None):
        assert mode in ['dev', 'valid',
                        'test'], "mode should be dev/valid/test"
        self.mode = mode
        self.transform = transform
        self.file_dir = pathlib.Path(get_download_dir(), mode)
        self.zip_file_path = pathlib.Path(get_download_dir(), '%s.zip' % mode)
        download(_urls['Alchemy'] + "%s.zip" % mode,
                 path=str(self.zip_file_path))
        extract_archive(str(self.zip_file_path), str(self.file_dir))

        self._load()
Code Example #20
File: knowledge_graph.py, Project: Lee-zix/RE-GCN
    def __init__(self, name, dir=None):
        self.name = name
        if dir:
            self.dir = dir
            self.dir = os.path.join(self.dir, self.name)

        else:
            self.dir = get_download_dir()
            tgz_path = os.path.join(self.dir, '{}.tar.gz'.format(self.name))
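            # '_downlaod_prefix' (sic) again matches the misspelled upstream constant;
            # note the remote '{name}.tgz' is saved locally under the '.tar.gz' path above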
            download(_downlaod_prefix + '{}.tgz'.format(self.name), tgz_path)
            self.dir = os.path.join(self.dir, self.name)
            extract_archive(tgz_path, self.dir)
        print(self.dir)
Code Example #21
 def __init__(self, mode='train', vocab_file=None):
     self.mode = mode
     self.dir = get_download_dir()
     self.zip_file_path = '{}/sst.zip'.format(self.dir)
     self.pretrained_file = 'glove.840B.300d.txt' if mode == 'train' else ''
     self.pretrained_emb = None
     self.vocab_file = '{}/sst/vocab.txt'.format(self.dir) if vocab_file is None else vocab_file
     download(_get_dgl_url(_urls['sst']), path=self.zip_file_path)
     extract_archive(self.zip_file_path, '{}/sst'.format(self.dir))
     self.trees = []
     self.num_classes = 5
     print('Preprocessing...')
     self._load()
     print('Dataset creation finished. #Trees:', len(self.trees))
Code Example #22
File: datautils.py, Project: zhp510730568/dgl
 def __init__(self, data, vocab, training=True):
     self.dir = get_download_dir()
     self.zip_file_path = '{}/jtnn.zip'.format(self.dir)
     download(_url, path=self.zip_file_path)
     extract_archive(self.zip_file_path, '{}/jtnn'.format(self.dir))
     print('Loading data...')
     data_file = '{}/jtnn/{}.txt'.format(self.dir, data)
     with open(data_file) as f:
         self.data = [line.strip("\r\n ").split()[0] for line in f]
     self.vocab_file = '{}/jtnn/{}.txt'.format(self.dir, vocab)
     print('Loading finished.')
     print('\tNum samples:', len(self.data))
     print('\tVocab file:', self.vocab_file)
     self.training = training
     self.vocab = Vocab([x.strip("\r\n ") for x in open(self.vocab_file)])
Code Example #23
def test_acnn():
    remove_dir('tmp1')
    remove_dir('tmp2')

    url = _get_dgl_url('dgllife/example_mols.tar.gz')
    local_path = 'tmp1/example_mols.tar.gz'
    download(url, path=local_path)
    extract_archive(local_path, 'tmp2')

    pocket_mol, pocket_coords = load_molecule(
        'tmp2/example_mols/example.pdb', remove_hs=True)
    ligand_mol, ligand_coords = load_molecule(
        'tmp2/example_mols/example.pdbqt', remove_hs=True)

    remove_dir('tmp1')
    remove_dir('tmp2')

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
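
    # Note: this test assumes an older DGL API; on recent releases
    # Graph.to(device) returns a new graph instead of mutating in place,
    # and dgl.batch_hetero has been merged into dgl.batch.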

    g1 = ACNN_graph_construction_and_featurization(ligand_mol,
                                                   pocket_mol,
                                                   ligand_coords,
                                                   pocket_coords)

    model = ACNN()
    model.to(device)
    g1.to(device)
    assert model(g1).shape == torch.Size([1, 1])

    bg = dgl.batch_hetero([g1, g1])
    bg.to(device)
    assert model(bg).shape == torch.Size([2, 1])

    model = ACNN(hidden_sizes=[1, 2],
                 weight_init_stddevs=[1, 1],
                 dropouts=[0.1, 0.],
                 features_to_use=torch.tensor([6., 8.]),
                 radial=[[12.0], [0.0, 2.0], [4.0]])
    model.to(device)
    g1.to(device)
    assert model(g1).shape == torch.Size([1, 1])

    bg = dgl.batch_hetero([g1, g1])
    bg.to(device)
    assert model(bg).shape == torch.Size([2, 1])
Code Example #24
    def __init__(self,
                 smiles_to_graph=smiles_to_bigraph,
                 node_featurizer=None,
                 edge_featurizer=None,
                 load=True,
                 log_every=1000,
                 cache_file_path='esol_dglgraph.bin'):

        self._url = 'dataset/ESOL.zip'
        data_path = get_download_dir() + '/ESOL.zip'
        dir_path = get_download_dir() + '/ESOL'
        download(_get_dgl_url(self._url), path=data_path)
        extract_archive(data_path, dir_path)
        df = pd.read_csv(dir_path + '/delaney-processed.csv')

        # Compound names in PubChem
        self.compound_names = df['Compound ID'].tolist()
        # Estimated solubility
        self.estimated_solubility = df['ESOL predicted log solubility in mols per litre'].tolist()
        # Minimum atom degree
        self.min_degree = df['Minimum Degree'].tolist()
        # Molecular weight
        self.mol_weight = df['Molecular Weight'].tolist()
        # Number of H-Bond Donors
        self.num_h_bond_donors = df['Number of H-Bond Donors'].tolist()
        # Number of rings
        self.num_rings = df['Number of Rings'].tolist()
        # Number of rotatable bonds
        self.num_rotatable_bonds = df['Number of Rotatable Bonds'].tolist()
        # Polar Surface Area
        self.polar_surface_area = df['Polar Surface Area'].tolist()

        self.load_full = False

        super(ESOL, self).__init__(df=df,
                                   smiles_to_graph=smiles_to_graph,
                                   node_featurizer=node_featurizer,
                                   edge_featurizer=edge_featurizer,
                                   smiles_column='smiles',
                                   cache_file_path=cache_file_path,
                                   task_names=['measured log solubility in mols per litre'],
                                   load=load,
                                   log_every=log_every,
                                   init_mask=False)
Code Example #25
File: test_io_utils.py, Project: zwvews/dgl-lifesci
def test_load_molecule():
    remove_dir('tmp1')
    remove_dir('tmp2')

    url = _get_dgl_url('dgllife/example_mols.tar.gz')
    local_path = 'tmp1/example_mols.tar.gz'
    download(url, path=local_path)
    extract_archive(local_path, 'tmp2')

    load_molecule('tmp2/example_mols/example.sdf')
    load_molecule('tmp2/example_mols/example.mol2',
                  use_conformation=False,
                  sanitize=True)
    load_molecule('tmp2/example_mols/example.pdbqt', calc_charges=True)
    mol, _ = load_molecule('tmp2/example_mols/example.pdb', remove_hs=True)
    assert mol.GetNumAtoms() == mol.GetNumHeavyAtoms()

    remove_dir('tmp1')
    remove_dir('tmp2')
Code Example #26
def create_generative_model(model_name):
    """Create a model.

    Parameters
    ----------
    model_name : str
        Name for the model.

    Returns
    -------
    Created model
    """
    if model_name.startswith('DGMG'):
        if model_name.startswith('DGMG_ChEMBL'):
            atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
        elif model_name.startswith('DGMG_ZINC'):
            atom_types = ['Br', 'S', 'C', 'P', 'N', 'O', 'F', 'Cl', 'I']
        bond_types = [
            Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE
        ]

        return DGMG(atom_types=atom_types,
                    bond_types=bond_types,
                    node_hidden_size=128,
                    num_prop_rounds=2,
                    dropout=0.2)

    elif model_name == "JTNN_ZINC":
        default_dir = get_download_dir()
        vocab_file = '{}/jtvae/{}.txt'.format(default_dir, 'vocab')
        if not os.path.exists(vocab_file):
            zip_file_path = '{}/jtvae.zip'.format(default_dir)
            download(_get_dgl_url('dataset/jtvae.zip'), path=zip_file_path)
            extract_archive(zip_file_path, '{}/jtvae'.format(default_dir))
        return DGLJTNNVAE(vocab_file=vocab_file,
                          depth=3,
                          hidden_size=450,
                          latent_size=56)

    else:
        return None
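
A hedged usage sketch, with model names taken from the branches above ('DGMG_ZINC_canonical' is one plausible concrete name matching the 'DGMG_ZINC' prefix):

jtnn = create_generative_model('JTNN_ZINC')            # JTNN VAE with the ZINC vocabulary
dgmg = create_generative_model('DGMG_ZINC_canonical')  # hits the 'DGMG_ZINC' branch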
Code Example #27
File: jtvae.py, Project: runom/dgl-lifesci
    def __init__(self, subset, vocab, cache=False):
        dir = get_download_dir()
        _url = _get_dgl_url('dataset/jtvae.zip')
        zip_file_path = '{}/jtvae.zip'.format(dir)
        download(_url, path=zip_file_path, overwrite=False)
        extract_archive(zip_file_path, '{}/jtvae'.format(dir))

        if subset == 'train':
            super(JTVAEZINC,
                  self).__init__(data_file='{}/jtvae/train.txt'.format(dir),
                                 vocab=vocab,
                                 cache=cache)
        elif subset == 'test':
            super(JTVAEZINC,
                  self).__init__(data_file='{}/jtvae/test.txt'.format(dir),
                                 vocab=vocab,
                                 cache=cache,
                                 training=False)
        else:
            raise ValueError(
                "Expect subset to be 'train' or 'test', got {}".format(subset))
Code Example #28
    def __init__(self,
                 name,
                 device,
                 mix_cpu_gpu=False,
                 use_one_hot_fea=True,
                 symm=True,
                 test_ratio=0.1,
                 valid_ratio=0.1,
                 sparse_ratio=0,
                 sample_rate=3):
        self._name = name
        self._device = device
        self._symm = symm
        self._test_ratio = test_ratio
        self._valid_ratio = valid_ratio
        self._dir = os.path.join(_paths[self._name])
        self.sample_rate = sample_rate
        print(self._name[0:5])
        if self._name in ['ml-100k', 'ml-1m', 'ml-10m']:
            # download and extract
            download_dir = get_download_dir()
            print("download_dir: ", download_dir)
            zip_file_path = '{}/{}.zip'.format(download_dir, name)
            download(_urls[name], path=zip_file_path)
            extract_archive(zip_file_path, '{}/{}'.format(download_dir, name))
            if name == 'ml-10m':
                root_folder = 'ml-10M100K'
            else:
                root_folder = name
            self._dir = os.path.join(download_dir, name, root_folder)
            print("Starting processing {} ...".format(self._name))
            self._load_raw_user_info()
            self._load_raw_movie_info()
            print('......')
            if self._name == 'ml-100k':
                self.all_train_rating_info = self._load_raw_rates(
                    os.path.join(self._dir, 'u1.base'), '\t')
                self.test_rating_info = self._load_raw_rates(
                    os.path.join(self._dir, 'u1.test'), '\t')
                self.all_rating_info = pd.concat(
                    [self.all_train_rating_info, self.test_rating_info])
            elif self._name == 'ml-1m' or self._name == 'ml-10m':
                self.all_rating_info = self._load_raw_rates(
                    os.path.join(self._dir, 'ratings.dat'), '::')
                num_test = int(
                    np.ceil(self.all_rating_info.shape[0] * self._test_ratio))
                shuffled_idx = np.random.permutation(
                    self.all_rating_info.shape[0])
                self.test_rating_info = self.all_rating_info.iloc[
                    shuffled_idx[:num_test]]
                self.all_train_rating_info = self.all_rating_info.iloc[
                    shuffled_idx[num_test:]]
            else:
                raise NotImplementedError
            print('......')
            num_valid = int(
                np.ceil(self.all_train_rating_info.shape[0] *
                        self._valid_ratio))
            shuffled_idx = np.random.permutation(
                self.all_train_rating_info.shape[0])
            self.valid_rating_info = self.all_train_rating_info.iloc[
                shuffled_idx[:num_valid]]
            self.train_rating_info = self.all_train_rating_info.iloc[
                shuffled_idx[num_valid:]]
            self.possible_rating_values = np.append(
                np.unique(self.train_rating_info["rating"].values), 0)
        elif self._name in ['Tmall', 'Tmall_small'] or self._name[0:5] == 'Tmall':
            #self.all_rating_info, M = self._load_tmall(os.path.join(_paths[self._name]))
            #print(self._name[0:5])
            self.all_rating_info = self._load_raw_rates_Tmall(
                os.path.join(_paths[self._name]), ' ')
            print(self.all_rating_info)
            num_test = int(
                np.ceil(self.all_rating_info.shape[0] *
                        (1 - self._test_ratio)))
            shuffled_idx = np.random.permutation(self.all_rating_info.shape[0])
            #self.test_rating_info = self.all_rating_info.iloc[shuffled_idx[: num_test]]
            #self.all_train_rating_info = self.all_rating_info.iloc[shuffled_idx[num_test: ]]
            self.test_rating_info = self.all_rating_info.iloc[num_test:]
            self.all_train_rating_info = self.all_rating_info.iloc[:num_test]
            print("self.all_train_rating_info")
            print(self.all_train_rating_info)

            user_list = pd.unique(self.all_rating_info["user_id"].values)
            item_list = pd.unique(self.all_rating_info["movie_id"].values)

            #print("*******", user_list)

            user_nodes, item_nodes = user_list, item_list
            print('......')
            num_valid = int(
                np.ceil(self.all_train_rating_info.shape[0] *
                        self._valid_ratio))
            #shuffled_idx = np.random.permutation(self.all_train_rating_info.shape[0])
            #self.valid_rating_info = self.all_train_rating_info.iloc[shuffled_idx[: num_valid]]
            #self.train_rating_info = self.all_train_rating_info.iloc[shuffled_idx[num_valid: ]]
            self.valid_rating_info = self.all_train_rating_info.iloc[:num_valid]
            self.train_rating_info = self.all_train_rating_info.iloc[num_valid:]
            shuffled_idx = np.random.permutation(
                self.train_rating_info.shape[0])
            self.train_rating_info = self.train_rating_info.iloc[shuffled_idx]

            self.possible_rating_values = np.append(
                np.unique(self.train_rating_info["rating"].values), 0)
            #print(self.possible_rating_values)
        else:
            raise NotImplementedError

        self.user_poll = set(pd.unique(self.all_rating_info["user_id"].values))
        self.item_poll = set(pd.unique(
            self.all_rating_info["movie_id"].values))

        self.negatives = self.sample_negative(self.train_rating_info,
                                              self.sample_rate,
                                              random_number=1)

        print("All rating pairs : {}".format(self.all_rating_info.shape[0]))
        print("\tAll train rating pairs : {}".format(
            self.all_train_rating_info.shape[0]))
        print("\t\tTrain rating pairs : {}".format(
            self.train_rating_info.shape[0]))
        print("\t\tValid rating pairs : {}".format(
            self.valid_rating_info.shape[0]))
        print("\tTest rating pairs : {}".format(
            self.test_rating_info.shape[0]))

        if self._name in ['ml-100k', 'ml-1m', 'ml-10m']:
            self.user_info = self._drop_unseen_nodes(
                orign_info=self.user_info,
                cmp_col_name="id",
                reserved_ids_set=set(self.all_rating_info["user_id"].values),
                label="user")
            self.movie_info = self._drop_unseen_nodes(
                orign_info=self.movie_info,
                cmp_col_name="id",
                reserved_ids_set=set(self.all_rating_info["movie_id"].values),
                label="movie")

            # Map user/movie to the global id
            self.global_user_id_map = {
                ele: i
                for i, ele in enumerate(self.user_info['id'])
            }
            self.global_movie_id_map = {
                ele: i
                for i, ele in enumerate(self.movie_info['id'])
            }
        elif self._name in [
                'flixster', 'douban', 'yahoo_music', 'Tmall', 'Tmall_small'
        ] or self._name[0:5] == 'Tmall':
            self.global_user_id_map = bidict({})
            self.global_movie_id_map = bidict({})
            # max_uid = 0
            # max_vid = 0
            print("user and item number:")
            # print(user_nodes)
            # print(item_nodes)
            for i in range(len(user_nodes)):
                self.global_user_id_map[user_nodes[i]] = i
            for i in range(len(item_nodes)):
                self.global_movie_id_map[item_nodes[i]] = i
        else:
            raise NotImplementedError

        print('Total user number = {}, movie number = {}'.format(
            len(self.global_user_id_map), len(self.global_movie_id_map)))
        self._num_user = len(self.global_user_id_map)
        self._num_movie = len(self.global_movie_id_map)
        ### Generate features
        if use_one_hot_fea:
            self.user_feature = None
            self.movie_feature = None
        else:
            raise NotImplementedError

        if self.user_feature is None:
            self.user_feature_shape = (self.num_user,
                                       self.num_user + self.num_movie + 3)
            self.movie_feature_shape = (self.num_movie,
                                        self.num_user + self.num_movie + 3)
            if mix_cpu_gpu:
                self.user_feature = th.cat([
                    th.Tensor(list(range(3, self.num_user + 3))).reshape(
                        -1, 1),
                    th.zeros([self.num_user, 1]) + 1,
                    th.zeros([self.num_user, 1])
                ], 1)
                self.movie_feature = th.cat([
                    th.Tensor(list(range(3, self.num_movie + 3))).reshape(
                        -1, 1),
                    th.ones([self.num_movie, 1]) + 1,
                    th.zeros([self.num_movie, 1])
                ], 1)
                #self.movie_feature = th.cat([th.Tensor(list(range(self.num_user+3, self.num_user + self.num_movie + 3))).reshape(-1, 1), th.ones([self.num_movie, 1])+1, th.zeros([self.num_movie, 1])], 1)
            else:
                self.user_feature = th.cat([
                    th.Tensor(list(range(3, self.num_user + 3))).reshape(
                        -1, 1),
                    th.zeros([self.num_user, 1]) + 1,
                    th.zeros([self.num_user, 1])
                ], 1).to(self._device)
                self.movie_feature = th.cat([
                    th.Tensor(
                        list(
                            range(self.num_user + 3, self.num_user +
                                  self.num_movie + 3))).reshape(-1, 1),
                    th.ones([self.num_movie, 1]) + 1,
                    th.zeros([self.num_movie, 1])
                ], 1).to(self._device)
        else:
            raise NotImplementedError
        #print(self.user_feature.shape)
        info_line = "Feature dim: "
        info_line += "\nuser: {}".format(self.user_feature_shape)
        info_line += "\nmovie: {}".format(self.movie_feature_shape)
        print(info_line)

        all_train_rating_pairs, all_train_rating_values = self._generate_pair_value(
            self.all_train_rating_info)
        train_rating_pairs, train_rating_values = self._generate_pair_value(
            self.train_rating_info)
        valid_rating_pairs, valid_rating_values = self._generate_pair_value(
            self.valid_rating_info)
        test_rating_pairs, test_rating_values = self._generate_pair_value(
            self.test_rating_info)

        def _make_labels(ratings):
            labels = th.LongTensor(
                np.searchsorted(self.possible_rating_values,
                                ratings)).to(device)
            return labels

        self.train_enc_graph = self._generate_enc_graph(train_rating_pairs,
                                                        train_rating_values,
                                                        add_support=True)
        self.train_dec_graph = self._generate_dec_graph(train_rating_pairs)
        self.train_labels = _make_labels(train_rating_values)
        self.train_truths = th.FloatTensor(train_rating_values).to(device)

        self.valid_enc_graph = self.train_enc_graph
        self.valid_dec_graph = self._generate_dec_graph(valid_rating_pairs)
        self.valid_labels = _make_labels(valid_rating_values)
        self.valid_truths = th.FloatTensor(valid_rating_values).to(device)

        self.test_enc_graph = self._generate_enc_graph(all_train_rating_pairs,
                                                       all_train_rating_values,
                                                       add_support=True)
        self.test_dec_graph = self._generate_dec_graph(test_rating_pairs)
        self.test_labels = _make_labels(test_rating_values)
        self.test_truths = th.FloatTensor(test_rating_values).to(device)

        # Build a graph used to test recall on held-out data
        #self.test_recall_labels = _make_labels(self.test_rating_info)
        valid_recall_pair, valid_rating_matrix = self._generate_pair_value_for_recall(
            self.valid_rating_info)
        self.valid_rating_matrix = th.FloatTensor(valid_rating_matrix).to(
            device)
        self.valid_recall_dec_graph = self._generate_dec_graph(
            valid_recall_pair)

        test_recall_pair, test_rating_matrix = self._generate_pair_value_for_recall(
            self.test_rating_info)
        self.test_rating_matrix = th.FloatTensor(test_rating_matrix).to(device)
        self.test_recall_dec_graph = self._generate_dec_graph(test_recall_pair)

        def _npairs(graph):
            rst = 0
            for r in self.possible_rating_values:
                r = to_etype_name(r)
                rst += graph.number_of_edges(str(r))
            return rst

        print("Train enc graph: \t#user:{}\t#movie:{}\t#pairs:{}".format(
            self.train_enc_graph.number_of_nodes('user'),
            self.train_enc_graph.number_of_nodes('movie'),
            _npairs(self.train_enc_graph)))
        print("Train dec graph: \t#user:{}\t#movie:{}\t#pairs:{}".format(
            self.train_dec_graph.number_of_nodes('user'),
            self.train_dec_graph.number_of_nodes('movie'),
            self.train_dec_graph.number_of_edges()))
        print("Valid enc graph: \t#user:{}\t#movie:{}\t#pairs:{}".format(
            self.valid_enc_graph.number_of_nodes('user'),
            self.valid_enc_graph.number_of_nodes('movie'),
            _npairs(self.valid_enc_graph)))
        print("Valid dec graph: \t#user:{}\t#movie:{}\t#pairs:{}".format(
            self.valid_dec_graph.number_of_nodes('user'),
            self.valid_dec_graph.number_of_nodes('movie'),
            self.valid_dec_graph.number_of_edges()))
        print("Test enc graph: \t#user:{}\t#movie:{}\t#pairs:{}".format(
            self.test_enc_graph.number_of_nodes('user'),
            self.test_enc_graph.number_of_nodes('movie'),
            _npairs(self.test_enc_graph)))
        print("Test dec graph: \t#user:{}\t#movie:{}\t#pairs:{}".format(
            self.test_dec_graph.number_of_nodes('user'),
            self.test_dec_graph.number_of_nodes('movie'),
            self.test_dec_graph.number_of_edges()))
Code Example #29
    def __init__(self,
                 name,
                 device,
                 mix_cpu_gpu=False,
                 use_one_hot_fea=False,
                 symm=True,
                 test_ratio=0.1,
                 valid_ratio=0.1):
        self._name = name
        self._device = device
        self._symm = symm
        self._test_ratio = test_ratio
        self._valid_ratio = valid_ratio
        # download and extract
        download_dir = get_download_dir()
        zip_file_path = '{}/{}.zip'.format(download_dir, name)
        download(_urls[name], path=zip_file_path)
        extract_archive(zip_file_path, '{}/{}'.format(download_dir, name))
        if name == 'ml-10m':
            root_folder = 'ml-10M100K'
        else:
            root_folder = name
        self._dir = os.path.join(download_dir, name, root_folder)
        print("Starting processing {} ...".format(self._name))
        self._load_raw_user_info()
        self._load_raw_movie_info()
        print('......')
        if self._name == 'ml-100k':
            self.all_train_rating_info = self._load_raw_rates(
                os.path.join(self._dir, 'u1.base'), '\t')
            self.test_rating_info = self._load_raw_rates(
                os.path.join(self._dir, 'u1.test'), '\t')
            self.all_rating_info = pd.concat(
                [self.all_train_rating_info, self.test_rating_info])
        elif self._name == 'ml-1m' or self._name == 'ml-10m':
            self.all_rating_info = self._load_raw_rates(
                os.path.join(self._dir, 'ratings.dat'), '::')
            num_test = int(
                np.ceil(self.all_rating_info.shape[0] * self._test_ratio))
            shuffled_idx = np.random.permutation(self.all_rating_info.shape[0])
            self.test_rating_info = self.all_rating_info.iloc[
                shuffled_idx[:num_test]]
            self.all_train_rating_info = self.all_rating_info.iloc[
                shuffled_idx[num_test:]]
        else:
            raise NotImplementedError
        print('......')
        num_valid = int(
            np.ceil(self.all_train_rating_info.shape[0] * self._valid_ratio))
        shuffled_idx = np.random.permutation(
            self.all_train_rating_info.shape[0])
        self.valid_rating_info = self.all_train_rating_info.iloc[
            shuffled_idx[:num_valid]]
        self.train_rating_info = self.all_train_rating_info.iloc[
            shuffled_idx[num_valid:]]
        self.possible_rating_values = np.unique(
            self.train_rating_info["rating"].values)

        print("All rating pairs : {}".format(self.all_rating_info.shape[0]))
        print("\tAll train rating pairs : {}".format(
            self.all_train_rating_info.shape[0]))
        print("\t\tTrain rating pairs : {}".format(
            self.train_rating_info.shape[0]))
        print("\t\tValid rating pairs : {}".format(
            self.valid_rating_info.shape[0]))
        print("\tTest rating pairs  : {}".format(
            self.test_rating_info.shape[0]))

        print("AAA")
        print(len(self.user_info))
        print(self.user_info)
        self.user_info = self._drop_unseen_nodes(
            orign_info=self.user_info,
            cmp_col_name="id",
            reserved_ids_set=set(self.all_rating_info["user_id"].values),
            label="user")
        self.movie_info = self._drop_unseen_nodes(
            orign_info=self.movie_info,
            cmp_col_name="id",
            reserved_ids_set=set(self.all_rating_info["movie_id"].values),
            label="movie")
        print(len(self.user_info))
        print(self.user_info)
        # Map user/movie to the global id
        self.global_user_id_map = {
            ele: i
            for i, ele in enumerate(self.user_info['id'])
        }
        print(self.global_user_id_map)
        self.global_movie_id_map = {
            ele: i
            for i, ele in enumerate(self.movie_info['id'])
        }
        print('Total user number = {}, movie number = {}'.format(
            len(self.global_user_id_map), len(self.global_movie_id_map)))
        self._num_user = len(self.global_user_id_map)
        self._num_movie = len(self.global_movie_id_map)

        ### Generate features
        if use_one_hot_fea:
            self.user_feature = None
            self.movie_feature = None
        else:
            # if mix_cpu_gpu, we put features in CPU
            if mix_cpu_gpu:
                self.user_feature = th.FloatTensor(self._process_user_fea())
                self.movie_feature = th.FloatTensor(self._process_movie_fea())
            else:
                self.user_feature = th.FloatTensor(
                    self._process_user_fea()).to(self._device)
                self.movie_feature = th.FloatTensor(
                    self._process_movie_fea()).to(self._device)
        if self.user_feature is None:
            self.user_feature_shape = (self.num_user, self.num_user)
            self.movie_feature_shape = (self.num_movie, self.num_movie)
        else:
            self.user_feature_shape = self.user_feature.shape
            self.movie_feature_shape = self.movie_feature.shape
        info_line = "Feature dim: "
        info_line += "\nuser: {}".format(self.user_feature_shape)
        info_line += "\nmovie: {}".format(self.movie_feature_shape)
        print(info_line)

        all_train_rating_pairs, all_train_rating_values = self._generate_pair_value(
            self.all_train_rating_info)
        train_rating_pairs, train_rating_values = self._generate_pair_value(
            self.train_rating_info)
        valid_rating_pairs, valid_rating_values = self._generate_pair_value(
            self.valid_rating_info)
        test_rating_pairs, test_rating_values = self._generate_pair_value(
            self.test_rating_info)

        def _make_labels(ratings):
            labels = th.LongTensor(
                np.searchsorted(self.possible_rating_values,
                                ratings)).to(device)
            return labels

        self.train_enc_graph = self._generate_enc_graph(train_rating_pairs,
                                                        train_rating_values,
                                                        add_support=True)
        self.train_dec_graph = self._generate_dec_graph(train_rating_pairs)
        self.train_labels = _make_labels(train_rating_values)
        self.train_truths = th.FloatTensor(train_rating_values).to(device)

        self.valid_enc_graph = self.train_enc_graph
        self.valid_dec_graph = self._generate_dec_graph(valid_rating_pairs)
        self.valid_labels = _make_labels(valid_rating_values)
        self.valid_truths = th.FloatTensor(valid_rating_values).to(device)

        self.test_enc_graph = self._generate_enc_graph(all_train_rating_pairs,
                                                       all_train_rating_values,
                                                       add_support=True)
        self.test_dec_graph = self._generate_dec_graph(test_rating_pairs)
        self.test_labels = _make_labels(test_rating_values)
        self.test_truths = th.FloatTensor(test_rating_values).to(device)

        def _npairs(graph):
            rst = 0
            for r in self.possible_rating_values:
                r = to_etype_name(r)
                rst += graph.number_of_edges(str(r))
            return rst

        print("Train enc graph: \t#user:{}\t#movie:{}\t#pairs:{}".format(
            self.train_enc_graph.number_of_nodes('user'),
            self.train_enc_graph.number_of_nodes('movie'),
            _npairs(self.train_enc_graph)))
        print("Train dec graph: \t#user:{}\t#movie:{}\t#pairs:{}".format(
            self.train_dec_graph.number_of_nodes('user'),
            self.train_dec_graph.number_of_nodes('movie'),
            self.train_dec_graph.number_of_edges()))
        print("Valid enc graph: \t#user:{}\t#movie:{}\t#pairs:{}".format(
            self.valid_enc_graph.number_of_nodes('user'),
            self.valid_enc_graph.number_of_nodes('movie'),
            _npairs(self.valid_enc_graph)))
        print("Valid dec graph: \t#user:{}\t#movie:{}\t#pairs:{}".format(
            self.valid_dec_graph.number_of_nodes('user'),
            self.valid_dec_graph.number_of_nodes('movie'),
            self.valid_dec_graph.number_of_edges()))
        print("Test enc graph: \t#user:{}\t#movie:{}\t#pairs:{}".format(
            self.test_enc_graph.number_of_nodes('user'),
            self.test_enc_graph.number_of_nodes('movie'),
            _npairs(self.test_enc_graph)))
        print("Test dec graph: \t#user:{}\t#movie:{}\t#pairs:{}".format(
            self.test_dec_graph.number_of_nodes('user'),
            self.test_dec_graph.number_of_nodes('movie'),
            self.test_dec_graph.number_of_edges()))
Code Example #30
File: wn18.py, Project: classicsong/dgl-graphloader
 def download(self):
     tgz_path = os.path.join(self.raw_dir, 'wn18.tgz')
     download(self.url, path=tgz_path)
     extract_archive(tgz_path, self.raw_path)
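
In this final example, url, raw_dir, and raw_path are presumably supplied by DGL's DGLDataset base class, which standardizes the download-then-extract flow that the rest of this page implements by hand.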