def test_to_device(index_dtype):
    g1 = dgl.heterograph({('user', 'plays', 'game'): [(0, 0), (1, 1)]},
                         index_dtype=index_dtype)
    g1.nodes['user'].data['h1'] = F.copy_to(F.tensor([[0.], [1.]]), F.cpu())
    g1.nodes['user'].data['h2'] = F.copy_to(F.tensor([[3.], [4.]]), F.cpu())
    g1.edges['plays'].data['h1'] = F.copy_to(F.tensor([[2.], [3.]]), F.cpu())

    g2 = dgl.heterograph({('user', 'plays', 'game'): [(0, 0), (1, 0)]},
                         index_dtype=index_dtype)
    g2.nodes['user'].data['h1'] = F.copy_to(F.tensor([[1.], [2.]]), F.cpu())
    g2.nodes['user'].data['h2'] = F.copy_to(F.tensor([[4.], [5.]]), F.cpu())
    g2.edges['plays'].data['h1'] = F.copy_to(F.tensor([[0.], [1.]]), F.cpu())

    bg = dgl.batch_hetero([g1, g2])

    if F.is_cuda_available():
        bg1 = bg.to(F.cuda())
        assert bg1 is not None
        assert bg.batch_size == bg1.batch_size
        assert bg.batch_num_nodes('user') == bg1.batch_num_nodes('user')
        assert bg.batch_num_edges('plays') == bg1.batch_num_edges('plays')

    # set feature
    g1 = dgl.heterograph({('user', 'plays', 'game'): [(0, 0), (1, 1)]},
                         index_dtype=index_dtype)
    g2 = dgl.heterograph({('user', 'plays', 'game'): [(0, 0), (1, 0)]},
                         index_dtype=index_dtype)
    bg = dgl.batch_hetero([g1, g2])
    if F.is_cuda_available():
        bg1 = bg.to(F.cuda())
        bg1.nodes['user'].data['test'] = F.copy_to(F.tensor([0, 1, 2, 3]),
                                                   F.cuda())
        bg1.edata['test'] = F.copy_to(F.tensor([0, 1, 2, 3]), F.cuda())
def test_batching_hetero_and_batched_hetero_topology(index_dtype):
    """Test batching a DGLHeteroGraph and a BatchedDGLHeteroGraph."""
    g1 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): [(0, 0), (1, 0)]
    }, index_dtype=index_dtype)
    g2 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): [(0, 0), (1, 0)]
    }, index_dtype=index_dtype)
    bg1 = dgl.batch_hetero([g1, g2])
    g3 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1)],
        ('user', 'plays', 'game'): [(1, 0)]
    }, index_dtype=index_dtype)
    bg2 = dgl.batch_hetero([bg1, g3])
    assert bg2.ntypes == g3.ntypes
    assert bg2.etypes == g3.etypes
    assert bg2.canonical_etypes == g3.canonical_etypes
    assert bg2.batch_size == 3

    # Test number of nodes
    for ntype in bg2.ntypes:
        assert bg2.batch_num_nodes(ntype) == [
            g1.number_of_nodes(ntype), g2.number_of_nodes(ntype), g3.number_of_nodes(ntype)]
        assert bg2.number_of_nodes(ntype) == (
                g1.number_of_nodes(ntype) + g2.number_of_nodes(ntype) + g3.number_of_nodes(ntype))

    # Test number of edges
    for etype in bg2.etypes:
        assert bg2.batch_num_edges(etype) == [
            g1.number_of_edges(etype), g2.number_of_edges(etype), g3.number_of_edges(etype)]
        assert bg2.number_of_edges(etype) == (
                g1.number_of_edges(etype) + g2.number_of_edges(etype) + g3.number_of_edges(etype))

    for etype in bg2.canonical_etypes:
        assert bg2.batch_num_edges(etype) == [
            g1.number_of_edges(etype), g2.number_of_edges(etype), g3.number_of_edges(etype)]
        assert bg2.number_of_edges(etype) == (
                g1.number_of_edges(etype) + g2.number_of_edges(etype) + g3.number_of_edges(etype))

    # Test relabeled nodes
    for ntype in bg2.ntypes:
        assert list(F.asnumpy(bg2.nodes(ntype))) == list(range(bg2.number_of_nodes(ntype)))

    # Test relabeled edges
    src, dst = bg2.all_edges(etype='follows')
    assert list(F.asnumpy(src)) == [0, 1, 3, 4, 6]
    assert list(F.asnumpy(dst)) == [1, 2, 4, 5, 7]
    src, dst = bg2.all_edges(etype='plays')
    assert list(F.asnumpy(src)) == [0, 1, 3, 4, 7]
    assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2]

    # Test unbatching graphs
    g4, g5, g6 = dgl.unbatch_hetero(bg2)
    check_equivalence_between_heterographs(g1, g4)
    check_equivalence_between_heterographs(g2, g5)
    check_equivalence_between_heterographs(g3, g6)
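The tests on this page call a check_equivalence_between_heterographs helper that is not included in the snippets. A minimal sketch of what such a helper might look like, assuming it only compares topology plus explicitly listed features (the real helper may check more):

def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=None):
    # Compare node/edge types and per-type counts.
    assert g1.ntypes == g2.ntypes
    assert g1.canonical_etypes == g2.canonical_etypes
    for ntype in g1.ntypes:
        assert g1.number_of_nodes(ntype) == g2.number_of_nodes(ntype)
    for etype in g1.canonical_etypes:
        assert g1.number_of_edges(etype) == g2.number_of_edges(etype)
        src1, dst1 = g1.all_edges(etype=etype)
        src2, dst2 = g2.all_edges(etype=etype)
        assert list(F.asnumpy(src1)) == list(F.asnumpy(src2))
        assert list(F.asnumpy(dst1)) == list(F.asnumpy(dst2))
    # Compare only the features the caller asked about, skipping absent keys.
    for ntype, keys in (node_attrs or {}).items():
        for key in keys:
            if key in g1.nodes[ntype].data:
                assert F.allclose(g1.nodes[ntype].data[key], g2.nodes[ntype].data[key])
    for etype, keys in (edge_attrs or {}).items():
        for key in keys:
            if key in g1.edges[etype].data:
                assert F.allclose(g1.edges[etype].data[key], g2.edges[etype].data[key])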
Example #3
    def train(g_, x):

        g = copy.deepcopy(g_)

        g_ = dgl.batch_hetero([g_ for _ in range(x.shape[0])])

        g = net(g, return_graph=True)

        g = unnorm(g)
        '''
        for term in ['bond', 'angle']:
            for param in ['k', 'eq']:
                g.nodes[term].data[param] = torch.exp(
                    g.nodes[term].data[param])
        '''

        g = dgl.batch_hetero([g for _ in range(x.shape[0])])

        g.nodes['atom'].data['xyz'] = torch.reshape(x, [-1, 3])

        g_.nodes['atom'].data['xyz'] = torch.reshape(x, [-1, 3])

        g = hgfp.mm.geometry_in_heterograph.from_heterograph_with_xyz(g)

        g = hgfp.mm.energy_in_heterograph.u(g)

        g_ = hgfp.mm.geometry_in_heterograph.from_heterograph_with_xyz(g_)

        g_ = hgfp.mm.energy_in_heterograph.u(g_)

        u = torch.sum(
            torch.cat(
                [
                    g_.nodes['mol'].data['u' + term][:, None] for term in [
                        'bond',
                        'angle',  # 'torsion', 'one_four', 'nonbonded'# , '0'
                    ]
                ],
                dim=1),
            dim=1)

        u_hat = torch.sum(
            torch.cat(
                [
                    g.nodes['mol'].data['u' + term][:, None] for term in [
                        'bond',
                        'angle'  # , 'torsion', 'one_four', 'nonbonded'# , '0'
                    ]
                ],
                dim=1),
            dim=1)

        return u, u_hat
def test_batching_with_zero_nodes_edges(index_dtype):
    """Test the features of batched DGLHeteroGraphs"""
    g1 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): []
    }, index_dtype=index_dtype)
    g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g1.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g1.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])

    g2 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): [(0, 0), (1, 0)]
    }, index_dtype=index_dtype)
    g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g2.nodes['game'].data['h1'] = F.tensor([[0.]])
    g2.nodes['game'].data['h2'] = F.tensor([[1.]])
    g2.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g2.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
    g2.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])

    bg = dgl.batch_hetero([g1, g2])

    assert F.allclose(bg.nodes['user'].data['h1'],
                      F.cat([g1.nodes['user'].data['h1'], g2.nodes['user'].data['h1']], dim=0))
    assert F.allclose(bg.nodes['user'].data['h2'],
                      F.cat([g1.nodes['user'].data['h2'], g2.nodes['user'].data['h2']], dim=0))
    assert F.allclose(bg.nodes['game'].data['h1'], g2.nodes['game'].data['h1'])
    assert F.allclose(bg.nodes['game'].data['h2'], g2.nodes['game'].data['h2'])
    assert F.allclose(bg.edges['follows'].data['h1'],
                      F.cat([g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']], dim=0))
    assert F.allclose(bg.edges['plays'].data['h1'], g2.edges['plays'].data['h1'])

    # Test unbatching graphs
    g3, g4 = dgl.unbatch_hetero(bg)
    check_equivalence_between_heterographs(
        g1, g3,
        node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
        edge_attrs={('user', 'follows', 'user'): ['h1']})
    check_equivalence_between_heterographs(
        g2, g4,
        node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
        edge_attrs={('user', 'follows', 'user'): ['h1']})

    # Test graphs without edges
    g1 = dgl.bipartite([], 'u', 'r', 'v', num_nodes=(0, 4))
    g2 = dgl.bipartite([], 'u', 'r', 'v', num_nodes=(1, 5))
    g2.nodes['u'].data['x'] = F.tensor([1])
    dgl.batch_hetero([g1, g2])
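    # Hedged sanity check (not in the original snippet): the empty bipartite
    # graph should contribute zero 'u' nodes and zero 'r' edges to the batch.
    bg = dgl.batch_hetero([g1, g2])
    assert bg.batch_num_nodes('u') == [0, 1]
    assert bg.batch_num_edges('r') == [0, 0]
    assert bg.number_of_nodes('v') == 9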
Example #5
    def train(theta, g_, x):

        g = copy.deepcopy(g_)

        g.nodes['bond'].data['k'] = theta[:18]
        g.nodes['bond'].data['eq'] = theta[18:36]
        g.nodes['angle'].data['k'] = theta[36:69]
        g.nodes['angle'].data['eq'] = theta[69:]

        g = unnorm(g)

        g_ = dgl.batch_hetero([g_ for _ in range(x.shape[0])])

        g = dgl.batch_hetero([g for _ in range(x.shape[0])])

        g.nodes['atom'].data['xyz'] = torch.reshape(x, [-1, 3])

        g_.nodes['atom'].data['xyz'] = torch.reshape(x, [-1, 3])

        g = hgfp.mm.geometry_in_heterograph.from_heterograph_with_xyz(g)

        g_ = hgfp.mm.geometry_in_heterograph.from_heterograph_with_xyz(g_)

        g = hgfp.mm.energy_in_heterograph.u(g)

        g_ = hgfp.mm.energy_in_heterograph.u(g_)

        u = torch.sum(
            torch.cat(
                [
                    g_.nodes['mol'].data['u' + term][:, None] for term in [
                        'bond',
                        'angle',  # 'torsion', 'one_four', 'nonbonded'# , '0'
                    ]
                ],
                dim=1),
            dim=1)

        u_hat = torch.sum(
            torch.cat(
                [
                    g.nodes['mol'].data['u' + term][:, None] for term in [
                        'bond',
                        'angle'  # , 'torsion', 'one_four', 'nonbonded'# , '0'
                    ]
                ],
                dim=1),
            dim=1)

        return u, u_hat, norm(g_), norm(g)
Example #6
    def batch(graphs):
        import dgl

        if all(isinstance(graph, esp.graphs.graph.Graph) for graph in graphs):
            return dgl.batch_hetero([graph.heterograph for graph in graphs])

        elif all(isinstance(graph, dgl.DGLGraph) for graph in graphs):
            return dgl.batch(graphs)

        elif all(isinstance(graph, dgl.DGLHeteroGraph) for graph in graphs):
            return dgl.batch_hetero(graphs)

        else:
            raise RuntimeError("Can only batch DGLGraph or DGLHeterograph,"
                               "now have %s" % type(graphs[0]))
def test_acnn():
    remove_dir('tmp1')
    remove_dir('tmp2')

    url = _get_dgl_url('dgllife/example_mols.tar.gz')
    local_path = 'tmp1/example_mols.tar.gz'
    download(url, path=local_path)
    extract_archive(local_path, 'tmp2')

    pocket_mol, pocket_coords = load_molecule(
        'tmp2/example_mols/example.pdb', remove_hs=True)
    ligand_mol, ligand_coords = load_molecule(
        'tmp2/example_mols/example.pdbqt', remove_hs=True)

    remove_dir('tmp1')
    remove_dir('tmp2')

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g1 = ACNN_graph_construction_and_featurization(ligand_mol,
                                                   pocket_mol,
                                                   ligand_coords,
                                                   pocket_coords)

    model = ACNN()
    model.to(device)
    g1.to(device)
    assert model(g1).shape == torch.Size([1, 1])

    bg = dgl.batch_hetero([g1, g1])
    bg.to(device)
    assert model(bg).shape == torch.Size([2, 1])

    model = ACNN(hidden_sizes=[1, 2],
                 weight_init_stddevs=[1, 1],
                 dropouts=[0.1, 0.],
                 features_to_use=torch.tensor([6., 8.]),
                 radial=[[12.0], [0.0, 2.0], [4.0]])
    model.to(device)
    g1.to(device)
    assert model(g1).shape == torch.Size([1, 1])

    bg = dgl.batch_hetero([g1, g1])
    bg.to(device)
    assert model(bg).shape == torch.Size([2, 1])
Example #8
def test_pickling_batched_heterograph():
    # copied from test_heterograph.create_test_heterograph()
    plays_spmat = ssp.coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 1], [0, 0, 1, 1])))
    wishes_nx = nx.DiGraph()
    wishes_nx.add_nodes_from(['u0', 'u1', 'u2'], bipartite=0)
    wishes_nx.add_nodes_from(['g0', 'g1'], bipartite=1)
    wishes_nx.add_edge('u0', 'g1', id=0)
    wishes_nx.add_edge('u2', 'g0', id=1)

    follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
    plays_g = dgl.bipartite(plays_spmat, 'user', 'plays', 'game')
    wishes_g = dgl.bipartite(wishes_nx, 'user', 'wishes', 'game')
    develops_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops',
                               'game')
    g = dgl.hetero_from_relations([follows_g, plays_g, wishes_g, develops_g])
    g2 = dgl.hetero_from_relations([follows_g, plays_g, wishes_g, develops_g])

    g.nodes['user'].data['u_h'] = F.randn((3, 4))
    g.nodes['game'].data['g_h'] = F.randn((2, 5))
    g.edges['plays'].data['p_h'] = F.randn((4, 6))
    g2.nodes['user'].data['u_h'] = F.randn((3, 4))
    g2.nodes['game'].data['g_h'] = F.randn((2, 5))
    g2.edges['plays'].data['p_h'] = F.randn((4, 6))

    bg = dgl.batch_hetero([g, g2])
    new_bg = _reconstruct_pickle(bg)
    test_utils.check_graph_equal(bg, new_bg)
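The _reconstruct_pickle helper used by both pickling examples is not shown here; a minimal sketch, assuming it simply round-trips the object through an in-memory pickle buffer:

import io
import pickle

def _reconstruct_pickle(obj):
    # Serialize to an in-memory buffer and load it back.
    f = io.BytesIO()
    pickle.dump(obj, f)
    f.seek(0)
    return pickle.load(f)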
Example #9
def test_pickling_batched_heterograph():
    # copied from test_heterograph.create_test_heterograph()
    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 1], [1, 2]),
        ('user', 'plays', 'game'): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ('user', 'wishes', 'game'): ([0, 2], [1, 0]),
        ('developer', 'develops', 'game'): ([0, 1], [0, 1])
    })
    g2 = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 1], [1, 2]),
        ('user', 'plays', 'game'): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ('user', 'wishes', 'game'): ([0, 2], [1, 0]),
        ('developer', 'develops', 'game'): ([0, 1], [0, 1])
    })

    g.nodes['user'].data['u_h'] = F.randn((3, 4))
    g.nodes['game'].data['g_h'] = F.randn((2, 5))
    g.edges['plays'].data['p_h'] = F.randn((4, 6))
    g2.nodes['user'].data['u_h'] = F.randn((3, 4))
    g2.nodes['game'].data['g_h'] = F.randn((2, 5))
    g2.edges['plays'].data['p_h'] = F.randn((4, 6))

    bg = dgl.batch_hetero([g, g2])
    new_bg = _reconstruct_pickle(bg)
    test_utils.check_graph_equal(bg, new_bg)
def test_batching_hetero_topology(index_dtype):
    """Test batching two DGLHeteroGraphs where some nodes are isolated in some relations"""
    g1 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'follows', 'developer'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1), (3, 1)]
    }, index_dtype=index_dtype)
    g2 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'follows', 'developer'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1)]
    }, index_dtype=index_dtype)
    bg = dgl.batch_hetero([g1, g2])

    assert bg.ntypes == g2.ntypes
    assert bg.etypes == g2.etypes
    assert bg.canonical_etypes == g2.canonical_etypes
    assert bg.batch_size == 2

    # Test number of nodes
    for ntype in bg.ntypes:
        assert bg.batch_num_nodes(ntype) == [
            g1.number_of_nodes(ntype), g2.number_of_nodes(ntype)]
        assert bg.number_of_nodes(ntype) == (
                g1.number_of_nodes(ntype) + g2.number_of_nodes(ntype))

    # Test number of edges
    assert bg.batch_num_edges('plays') == [
        g1.number_of_edges('plays'), g2.number_of_edges('plays')]
    assert bg.number_of_edges('plays') == (
        g1.number_of_edges('plays') + g2.number_of_edges('plays'))

    for etype in bg.canonical_etypes:
        assert bg.batch_num_edges(etype) == [
            g1.number_of_edges(etype), g2.number_of_edges(etype)]
        assert bg.number_of_edges(etype) == (
            g1.number_of_edges(etype) + g2.number_of_edges(etype))

    # Test relabeled nodes
    for ntype in bg.ntypes:
        assert list(F.asnumpy(bg.nodes(ntype))) == list(range(bg.number_of_nodes(ntype)))

    # Test relabeled edges
    src, dst = bg.all_edges(etype=('user', 'follows', 'user'))
    assert list(F.asnumpy(src)) == [0, 1, 4, 5]
    assert list(F.asnumpy(dst)) == [1, 2, 5, 6]
    src, dst = bg.all_edges(etype=('user', 'follows', 'developer'))
    assert list(F.asnumpy(src)) == [0, 1, 4, 5]
    assert list(F.asnumpy(dst)) == [1, 2, 4, 5]
    src, dst = bg.all_edges(etype='plays')
    assert list(F.asnumpy(src)) == [0, 1, 2, 3, 4, 5, 6]
    assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2, 2, 3]

    # Test unbatching graphs
    g3, g4 = dgl.unbatch_hetero(bg)
    check_equivalence_between_heterographs(g1, g3)
    check_equivalence_between_heterographs(g2, g4)
def test_batched_features(index_dtype):
    """Test the features of batched DGLHeteroGraphs"""
    g1 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): [(0, 0), (1, 0)]
    }, index_dtype=index_dtype)
    g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g1.nodes['game'].data['h1'] = F.tensor([[0.]])
    g1.nodes['game'].data['h2'] = F.tensor([[1.]])
    g1.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g1.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
    g1.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])

    g2 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): [(0, 0), (1, 0)]
    }, index_dtype=index_dtype)
    g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g2.nodes['game'].data['h1'] = F.tensor([[0.]])
    g2.nodes['game'].data['h2'] = F.tensor([[1.]])
    g2.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g2.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
    g2.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])

    bg = dgl.batch_hetero([g1, g2],
                          node_attrs=ALL,
                          edge_attrs={
                              ('user', 'follows', 'user'): 'h1',
                              ('user', 'plays', 'game'): None
                          })

    assert F.allclose(bg.nodes['user'].data['h1'],
                      F.cat([g1.nodes['user'].data['h1'], g2.nodes['user'].data['h1']], dim=0))
    assert F.allclose(bg.nodes['user'].data['h2'],
                      F.cat([g1.nodes['user'].data['h2'], g2.nodes['user'].data['h2']], dim=0))
    assert F.allclose(bg.nodes['game'].data['h1'],
                      F.cat([g1.nodes['game'].data['h1'], g2.nodes['game'].data['h1']], dim=0))
    assert F.allclose(bg.nodes['game'].data['h2'],
                      F.cat([g1.nodes['game'].data['h2'], g2.nodes['game'].data['h2']], dim=0))
    assert F.allclose(bg.edges['follows'].data['h1'],
                      F.cat([g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']], dim=0))
    assert 'h2' not in bg.edges['follows'].data.keys()
    assert 'h1' not in bg.edges['plays'].data.keys()

    # Test unbatching graphs
    g3, g4 = dgl.unbatch_hetero(bg)
    check_equivalence_between_heterographs(
        g1, g3,
        node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
        edge_attrs={('user', 'follows', 'user'): ['h1']})
    check_equivalence_between_heterographs(
        g2, g4,
        node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
        edge_attrs={('user', 'follows', 'user'): ['h1']})
def collate(data):
    indices, ligand_mols, protein_mols, graphs, labels = map(list, zip(*data))
    bg = dgl.batch_hetero(graphs)
    for nty in bg.ntypes:
        bg.set_n_initializer(dgl.init.zero_initializer, ntype=nty)
    for ety in bg.canonical_etypes:
        bg.set_e_initializer(dgl.init.zero_initializer, etype=ety)
    labels = torch.stack(labels, dim=0)

    return indices, ligand_mols, protein_mols, bg, labels
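A hedged sketch of how this collate function could be wired into a PyTorch DataLoader; complex_dataset is a placeholder for any dataset whose items are (index, ligand_mol, protein_mol, heterograph, label) tuples:

from torch.utils.data import DataLoader

# 'complex_dataset' is a placeholder, not defined in the original snippet.
loader = DataLoader(complex_dataset, batch_size=32, shuffle=True,
                    collate_fn=collate)
for indices, ligand_mols, protein_mols, bg, labels in loader:
    pass  # feed the batched heterograph and stacked labels to a model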
Example #13
def train(path, config, batch_size=16, learning_rate=1e-5, n_epoches=50):

    time_str = strftime("%Y-%m-%d_%H_%M_%S", localtime())
    os.mkdir(time_str)

    gs_, _ = dgl.data.utils.load_graphs(path)

    gs = []
    for g_ in gs_:
        g = hgfp.heterograph.from_graph(g_)
        g.nodes['atom'].data['q'] = g_.ndata['am1_charge'] + g_.ndata['bcc_charge']
        gs.append(g)

    gs_batched = []

    while True:
        try:
            gs_batched.append(dgl.batch_hetero([gs.pop(0) for _ in range(batch_size)]))
        except IndexError:  # fewer than batch_size graphs remain; drop the remainder
            break

    ds_tr, ds_te, ds_vl = hgfp.data.utils.split(gs_batched, 1, 1)

    net = gnn_charge.models.Net(config)
    eq = gnn_charge.eq.ChargeEquilibrium()

    opt = torch.optim.Adam(
        list(net.parameters()) + list(eq.parameters()),
        learning_rate)

    loss_fn = torch.nn.functional.mse_loss

    losses = np.array([0.])
    rmse_vl = []
    r2_vl = []
    rmse_tr = []
    r2_tr = []

    for idx_epoch in range(n_epoches):
        net.train()  # re-enable training mode; net.eval() is called after the epoch
        for g in ds_tr:
            g_hat = eq(net(g))
            q = g.nodes['atom'].data['q']
            q_hat = g_hat.nodes['atom'].data['q_hat']
            loss = loss_fn(q, q_hat)

            opt.zero_grad()
            loss.backward()
            opt.step()

            # record the training loss so the loss plot below has data to show
            losses = np.append(losses, loss.item())

        net.eval()

        u_tr = np.array([0.])
        u_hat_tr = np.array([0.])

        u_vl = np.array([0.])
        u_hat_vl = np.array([0.])

        with torch.no_grad():
            for g in ds_tr:
                u_hat = eq(net(g)).nodes['atom'].data['q_hat']
                u = g.nodes['atom'].data['q']
                u_tr = np.concatenate([u_tr, u.detach().numpy()], axis=0)
                u_hat_tr = np.concatenate([u_hat_tr, u_hat.detach().numpy()], axis=0)

            for g in ds_vl:
                u_hat = eq(net(g)).nodes['atom'].data['q_hat']
                u = g.nodes['atom'].data['q']
                u_vl = np.concatenate([u_vl, u.detach().numpy()], axis=0)
                u_hat_vl = np.concatenate([u_hat_vl, u_hat.detach().numpy()], axis=0)

        u_tr = u_tr[1:]
        u_vl = u_vl[1:]
        u_hat_tr = u_hat_tr[1:]
        u_hat_vl = u_hat_vl[1:]

        rmse_tr.append(
            np.sqrt(
                mean_squared_error(
                    u_tr,
                    u_hat_tr)))

        rmse_vl.append(
            np.sqrt(
                mean_squared_error(
                    u_vl,
                    u_hat_vl)))

        r2_tr.append(
            r2_score(
                u_tr,
                u_hat_tr))

        r2_vl.append(
            r2_score(
                u_vl,
                u_hat_vl))

        plt.style.use('fivethirtyeight')
        plt.figure()
        plt.plot(rmse_tr[1:], label=r'$RMSE_\mathtt{TRAIN}$')
        plt.plot(rmse_vl[1:], label=r'$RMSE_\mathtt{VALIDATION}$')
        plt.legend()
        plt.tight_layout()
        plt.savefig(time_str + '/RMSE.jpg')
        plt.close()
        plt.figure()
        plt.plot(r2_tr[1:], label=r'$R^2_\mathtt{TRAIN}$')
        plt.plot(r2_vl[1:], label=r'$R^2_\mathtt{VALIDATION}$')
        plt.legend()
        plt.tight_layout()
        plt.savefig(time_str + '/R2.jpg')
        plt.close()
        plt.figure()
        plt.plot(losses[10:])
        plt.title('loss')
        plt.tight_layout()
        plt.savefig(time_str + '/loss.jpg')
        plt.close()
Example #14
# net.load_state_dict(torch.load('/data/chodera/wangyq/hgfp_scripts/gcn_param/2020-03-30_11_41_19/model'))

net.load_state_dict(torch.load('model_multi.ds'))

mean_and_std_dict = torch.load('/data/chodera/wangyq/hgfp_scripts/gcn_param/2020-03-30_11_41_19/norm_dict')

norm, unnorm = hgfp.data.utils.get_norm_fn(mean_and_std_dict)

loss_fn = torch.nn.functional.mse_loss

for g_, u in ds:
    g_ = dgl.unbatch_hetero(g_)

    idxs = random.choices(list(range(len(g_))), k=64)
    
    g_ = dgl.batch_hetero(
            [g_[idx] for idx in idxs])
    
    g = copy.deepcopy(g_)
    
    u = u[idxs]

    g = net(g, return_graph=True)

    g = hgfp.mm.geometry_in_heterograph.from_heterograph_with_xyz(
        g)

    g = hgfp.mm.energy_in_heterograph.u(g)

    g = unnorm(g)

    u = torch.sum(
Example #15
    def forward(self, **params):
        '''
            words: [batch_size, max_length]
            src_lengths: [batch_size]
            mask: [batch_size, max_length]
            entity_type: [batch_size, max_length]
            entity_id: [batch_size, max_length]
            mention_id: [batch_size, max_length]
            distance: [batch_size, max_length]
            entity2mention_table: list of [local_entity_num, local_mention_num]
            graphs: list of DGLHeteroGraph
            h_t_pairs: [batch_size, h_t_limit, 2]
        '''
        src = self.word_emb(params['words'])
        mask = params['mask']
        bsz, slen, _ = src.size()

        if self.config.use_entity_type:
            src = torch.cat(
                [src, self.entity_type_emb(params['entity_type'])], dim=-1)

        if self.config.use_entity_id:
            src = torch.cat([src, self.entity_id_emb(params['entity_id'])],
                            dim=-1)

        # src: [batch_size, slen, encoder_input_size]
        # src_lengths: [batch_size]

        encoder_outputs, (output_h_t,
                          _) = self.encoder(src, params['src_lengths'])
        encoder_outputs[mask == 0] = 0
        # encoder_outputs: [batch_size, slen, 2*encoder_hid_size]
        # output_h_t: [batch_size, 2*encoder_hid_size]

        graphs = params['graphs']

        mention_id = params['mention_id']
        features = None

        for i in range(len(graphs)):
            encoder_output = encoder_outputs[i]  # [slen, 2*encoder_hid_size]
            mention_num = torch.max(mention_id[i])
            mention_index = get_cuda(
                (torch.arange(mention_num) + 1).unsqueeze(1).expand(
                    -1, slen))  # [mention_num, slen]
            mentions = mention_id[i].unsqueeze(0).expand(
                mention_num, -1)  # [mention_num, slen]
            select_metrix = (
                mention_index == mentions).float()  # [mention_num, slen]
            # average word -> mention
            word_total_numbers = torch.sum(select_metrix,
                                           dim=-1).unsqueeze(-1).expand(
                                               -1, slen)  # [mention_num, slen]
            select_metrix = torch.where(word_total_numbers > 0,
                                        select_metrix / word_total_numbers,
                                        select_metrix)
            x = torch.mm(select_metrix,
                         encoder_output)  # [mention_num, 2*encoder_hid_size]

            x = torch.cat((output_h_t[i].unsqueeze(0), x), dim=0)

            if features is None:
                features = x
            else:
                features = torch.cat((features, x), dim=0)

        graph_big = dgl.batch_hetero(graphs)
        output_features = [features]

        for GCN_layer in self.GCN_layers:
            features = GCN_layer(
                graph_big,
                {"node": features})["node"]  # [total_mention_nums, gcn_dim]
            output_features.append(features)

        output_feature = torch.cat(output_features, dim=-1)

        graphs = dgl.unbatch_hetero(graph_big)

        # mention -> entity
        entity2mention_table = params[
            'entity2mention_table']  # list of [entity_num, mention_num]
        entity_num = torch.max(params['entity_id'])
        entity_bank = get_cuda(torch.Tensor(bsz, entity_num, self.bank_size))
        global_info = get_cuda(torch.Tensor(bsz, self.bank_size))

        cur_idx = 0
        entity_graph_feature = None
        for i in range(len(graphs)):
            # average mention -> entity
            select_metrix = entity2mention_table[i].float(
            )  # [local_entity_num, mention_num]
            select_metrix[0][0] = 1
            mention_nums = torch.sum(select_metrix,
                                     dim=-1).unsqueeze(-1).expand(
                                         -1, select_metrix.size(1))
            select_metrix = torch.where(mention_nums > 0,
                                        select_metrix / mention_nums,
                                        select_metrix)
            node_num = graphs[i].number_of_nodes('node')
            entity_representation = torch.mm(
                select_metrix, output_feature[cur_idx:cur_idx + node_num])
            entity_bank[i, :select_metrix.size(0) -
                        1] = entity_representation[1:]
            global_info[i] = output_feature[cur_idx]
            cur_idx += node_num

            if entity_graph_feature is None:
                entity_graph_feature = entity_representation[
                    1:, -self.config.gcn_dim:]
            else:
                entity_graph_feature = torch.cat(
                    (entity_graph_feature,
                     entity_representation[1:, -self.config.gcn_dim:]),
                    dim=0)

        h_t_pairs = params['h_t_pairs']
        h_t_pairs = h_t_pairs + (h_t_pairs
                                 == 0).long() - 1  # [batch_size, h_t_limit, 2]
        h_t_limit = h_t_pairs.size(1)

        # [batch_size, h_t_limit, bank_size]
        h_entity_index = h_t_pairs[:, :, 0].unsqueeze(-1).expand(
            -1, -1, self.bank_size)
        t_entity_index = h_t_pairs[:, :, 1].unsqueeze(-1).expand(
            -1, -1, self.bank_size)

        # [batch_size, h_t_limit, bank_size]
        h_entity = torch.gather(input=entity_bank, dim=1, index=h_entity_index)
        t_entity = torch.gather(input=entity_bank, dim=1, index=t_entity_index)

        global_info = global_info.unsqueeze(1).expand(-1, h_t_limit, -1)

        entity_graphs = params['entity_graphs']
        entity_graph_big = dgl.batch(entity_graphs)
        self.edge_layer(entity_graph_big, entity_graph_feature)
        entity_graphs = dgl.unbatch(entity_graph_big)
        path_info = get_cuda(torch.zeros((bsz, h_t_limit, self.gcn_dim * 4)))
        relation_mask = params['relation_mask']
        path_table = params['path_table']
        for i in range(len(entity_graphs)):
            path_t = path_table[i]
            for j in range(h_t_limit):
                if relation_mask is not None and relation_mask[i,
                                                               j].item() == 0:
                    break

                h = h_t_pairs[i, j, 0].item()
                t = h_t_pairs[i, j, 1].item()
                # for evaluate
                if relation_mask is None and h == 0 and t == 0:
                    continue

                if (h + 1, t + 1) in path_t:
                    v = [val - 1 for val in path_t[(h + 1, t + 1)]]
                elif (t + 1, h + 1) in path_t:
                    v = [val - 1 for val in path_t[(t + 1, h + 1)]]
                else:
                    # no path between this head/tail pair: dump debug info and fail
                    print(h, t)
                    print(entity_graphs[i].all_edges())
                    print(h_t_pairs)
                    print(relation_mask)
                    assert False, 'no path found between the entity pair'

                middle_node_num = len(v)

                if middle_node_num == 0:
                    continue

                # forward
                edge_ids = get_cuda(entity_graphs[i].edge_ids(
                    [h for _ in range(middle_node_num)], v))
                forward_first = torch.index_select(entity_graphs[i].edata['h'],
                                                   dim=0,
                                                   index=edge_ids)
                edge_ids = get_cuda(entity_graphs[i].edge_ids(
                    v, [t for _ in range(middle_node_num)]))
                forward_second = torch.index_select(
                    entity_graphs[i].edata['h'], dim=0, index=edge_ids)

                # backward
                edge_ids = get_cuda(entity_graphs[i].edge_ids(
                    [t for _ in range(middle_node_num)], v))
                backward_first = torch.index_select(
                    entity_graphs[i].edata['h'], dim=0, index=edge_ids)
                edge_ids = get_cuda(entity_graphs[i].edge_ids(
                    v, [h for _ in range(middle_node_num)]))
                backward_second = torch.index_select(
                    entity_graphs[i].edata['h'], dim=0, index=edge_ids)

                tmp_path_info = torch.cat((forward_first, forward_second,
                                           backward_first, backward_second),
                                          dim=-1)
                _, attn_value = self.attention(
                    torch.cat((h_entity[i, j], t_entity[i, j]), dim=-1),
                    tmp_path_info)
                path_info[i, j] = attn_value

            entity_graphs[i].edata.pop('h')

        path_info = self.dropout(
            self.activation(self.path_info_mapping(path_info)))

        predictions = self.predict(
            torch.cat((h_entity, t_entity, torch.abs(h_entity - t_entity),
                       torch.mul(h_entity, t_entity), global_info, path_info),
                      dim=-1))
        return predictions
Example #16
def collate_movielens(data):
    g_list, label_list = map(list, zip(*data))
    g = dgl.batch_hetero(g_list)
    g_label = th.stack(label_list)
    return g, g_label
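Like the other collate functions above, this returns a single batched heterograph; a hedged sketch showing that the per-sample graphs can be recovered with dgl.unbatch_hetero (samples is an illustrative list of (heterograph, label-tensor) pairs):

# 'samples' is illustrative and not defined in the original snippet.
batched_graph, labels = collate_movielens(samples)
graphs = dgl.unbatch_hetero(batched_graph)
assert len(graphs) == batched_graph.batch_size == labels.shape[0]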