def test_to_device(index_dtype):
    """Check that batched heterographs can be moved to, and written on, the GPU.

    Two scenarios are covered, both of which only run their device-specific
    part when the backend reports CUDA as available:

    1. a batch that already carries CPU features is moved to the GPU and its
       batch bookkeeping is compared with the CPU original;
    2. a feature-less batch is moved to the GPU and new GPU features are
       assigned to it.
    """
    plays = ('user', 'plays', 'game')

    def cpu_tensor(values):
        return F.copy_to(F.tensor(values), F.cpu())

    def build(edges):
        return dgl.heterograph({plays: edges}, index_dtype=index_dtype)

    # Scenario 1: features live on the CPU before the move.
    g1 = build([(0, 0), (1, 1)])
    g1.nodes['user'].data['h1'] = cpu_tensor([[0.], [1.]])
    g1.nodes['user'].data['h2'] = cpu_tensor([[3.], [4.]])
    g1.edges['plays'].data['h1'] = cpu_tensor([[2.], [3.]])

    g2 = build([(0, 0), (1, 0)])
    g2.nodes['user'].data['h1'] = cpu_tensor([[1.], [2.]])
    g2.nodes['user'].data['h2'] = cpu_tensor([[4.], [5.]])
    g2.edges['plays'].data['h1'] = cpu_tensor([[0.], [1.]])

    bg = dgl.batch_hetero([g1, g2])
    if F.is_cuda_available():
        bg1 = bg.to(F.cuda())
        assert bg1 is not None
        assert bg.batch_size == bg1.batch_size
        assert bg.batch_num_nodes('user') == bg1.batch_num_nodes('user')
        assert bg.batch_num_edges('plays') == bg1.batch_num_edges('plays')

    # Scenario 2: set features after the batch has been moved.
    g1 = build([(0, 0), (1, 1)])
    g2 = build([(0, 0), (1, 0)])
    bg = dgl.batch_hetero([g1, g2])
    if F.is_cuda_available():
        bg1 = bg.to(F.cuda())
        bg1.nodes['user'].data['test'] = F.copy_to(
            F.tensor([0, 1, 2, 3]), F.cuda())
        bg1.edata['test'] = F.copy_to(F.tensor([0, 1, 2, 3]), F.cuda())
def test_batching_hetero_and_batched_hetero_topology(index_dtype):
    """Batch an already-batched heterograph together with a plain heterograph.

    The resulting batch must expose the same schema as its members, report
    per-member node/edge counts, relabel nodes and edges consecutively, and
    unbatch back into graphs equivalent to the three originals.
    """
    follows = ('user', 'follows', 'user')
    plays = ('user', 'plays', 'game')

    def build(follow_edges, play_edges):
        return dgl.heterograph(
            {follows: follow_edges, plays: play_edges},
            index_dtype=index_dtype)

    g1 = build([(0, 1), (1, 2)], [(0, 0), (1, 0)])
    g2 = build([(0, 1), (1, 2)], [(0, 0), (1, 0)])
    bg1 = dgl.batch_hetero([g1, g2])
    g3 = build([(0, 1)], [(1, 0)])
    bg2 = dgl.batch_hetero([bg1, g3])
    members = (g1, g2, g3)

    assert bg2.ntypes == g3.ntypes
    assert bg2.etypes == g3.etypes
    assert bg2.canonical_etypes == g3.canonical_etypes
    assert bg2.batch_size == 3

    # Test number of nodes
    for ntype in bg2.ntypes:
        per_graph = [g.number_of_nodes(ntype) for g in members]
        assert bg2.batch_num_nodes(ntype) == per_graph
        assert bg2.number_of_nodes(ntype) == (
            per_graph[0] + per_graph[1] + per_graph[2])

    # Test number of edges, addressed first by short and then by canonical
    # edge type
    for etype_names in (bg2.etypes, bg2.canonical_etypes):
        for etype in etype_names:
            per_graph = [g.number_of_edges(etype) for g in members]
            assert bg2.batch_num_edges(etype) == per_graph
            assert bg2.number_of_edges(etype) == (
                per_graph[0] + per_graph[1] + per_graph[2])

    # Test relabeled nodes
    for ntype in bg2.ntypes:
        node_ids = list(F.asnumpy(bg2.nodes(ntype)))
        assert node_ids == list(range(bg2.number_of_nodes(ntype)))

    # Test relabeled edges
    src, dst = bg2.all_edges(etype='follows')
    assert list(F.asnumpy(src)) == [0, 1, 3, 4, 6]
    assert list(F.asnumpy(dst)) == [1, 2, 4, 5, 7]
    src, dst = bg2.all_edges(etype='plays')
    assert list(F.asnumpy(src)) == [0, 1, 3, 4, 7]
    assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2]

    # Test unbatching graphs
    g4, g5, g6 = dgl.unbatch_hetero(bg2)
    for original, recovered in ((g1, g4), (g2, g5), (g3, g6)):
        check_equivalence_between_heterographs(original, recovered)
def train(g_, x):
    """Evaluate reference and network-parameterized energies on conformations.

    Parameters
    ----------
    g_ : heterograph carrying the reference parameters.
    x : tensor whose first dimension counts the conformations; it is reshaped
        to rows of three coordinates and stored as the atoms' ``xyz`` data.

    Returns
    -------
    (u, u_hat) : per-molecule sums of the ``ubond`` and ``uangle`` terms,
        computed on the reference graph and on the graph re-parameterized by
        the module-level ``net`` respectively.
    """
    def summed_energy(graph, terms):
        # One column per energy term, then reduce across the columns.
        columns = [
            graph.nodes['mol'].data['u' + term][:, None] for term in terms
        ]
        return torch.sum(torch.cat(columns, dim=1), dim=1)

    # Only the bond and angle terms are active; the torsion, one_four,
    # nonbonded and '0' terms were commented out in both sums.
    active_terms = ['bond', 'angle']

    g = copy.deepcopy(g_)
    g_ = dgl.batch_hetero([g_ for _ in range(x.shape[0])])

    g = net(g, return_graph=True)
    g = unnorm(g)
    # NOTE: an earlier revision exponentiated the 'k' and 'eq' parameters of
    # the bond and angle nodes at this point; that step is disabled.
    g = dgl.batch_hetero([g for _ in range(x.shape[0])])

    g.nodes['atom'].data['xyz'] = torch.reshape(x, [-1, 3])
    g_.nodes['atom'].data['xyz'] = torch.reshape(x, [-1, 3])

    g = hgfp.mm.geometry_in_heterograph.from_heterograph_with_xyz(g)
    g = hgfp.mm.energy_in_heterograph.u(g)

    g_ = hgfp.mm.geometry_in_heterograph.from_heterograph_with_xyz(g_)
    g_ = hgfp.mm.energy_in_heterograph.u(g_)

    u = summed_energy(g_, active_terms)
    u_hat = summed_energy(g, active_terms)
    return u, u_hat
def test_batching_with_zero_nodes_edges(index_dtype):
    """Batch heterographs in which a relation has no edges at all.

    ``g1`` has an empty 'plays' relation while ``g2`` does not; the batched
    features must therefore equal the concatenation for shared types and
    ``g2``'s own features for the 'game'/'plays' types.  Finally, two
    edge-less bipartite graphs are batched as a smoke test.
    """
    follows = ('user', 'follows', 'user')
    plays = ('user', 'plays', 'game')

    g1 = dgl.heterograph(
        {follows: [(0, 1), (1, 2)], plays: []}, index_dtype=index_dtype)
    g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g1.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g1.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])

    g2 = dgl.heterograph(
        {follows: [(0, 1), (1, 2)], plays: [(0, 0), (1, 0)]},
        index_dtype=index_dtype)
    g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g2.nodes['game'].data['h1'] = F.tensor([[0.]])
    g2.nodes['game'].data['h2'] = F.tensor([[1.]])
    g2.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g2.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
    g2.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])

    bg = dgl.batch_hetero([g1, g2])

    # Types populated in both graphs: batched data is the concatenation.
    for key in ('h1', 'h2'):
        expected = F.cat(
            [g1.nodes['user'].data[key], g2.nodes['user'].data[key]], dim=0)
        assert F.allclose(bg.nodes['user'].data[key], expected)
    # 'game' features are compared against g2's alone.
    for key in ('h1', 'h2'):
        assert F.allclose(bg.nodes['game'].data[key],
                          g2.nodes['game'].data[key])
    expected = F.cat(
        [g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']],
        dim=0)
    assert F.allclose(bg.edges['follows'].data['h1'], expected)
    assert F.allclose(bg.edges['plays'].data['h1'],
                      g2.edges['plays'].data['h1'])

    # Test unbatching graphs
    g3, g4 = dgl.unbatch_hetero(bg)
    for original, recovered in ((g1, g3), (g2, g4)):
        check_equivalence_between_heterographs(
            original, recovered,
            node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
            edge_attrs={('user', 'follows', 'user'): ['h1']})

    # Test graphs without edges
    g1 = dgl.bipartite([], 'u', 'r', 'v', num_nodes=(0, 4))
    g2 = dgl.bipartite([], 'u', 'r', 'v', num_nodes=(1, 5))
    g2.nodes['u'].data['x'] = F.tensor([1])
    dgl.batch_hetero([g1, g2])
def train(theta, g_, x):
    """Evaluate energies for reference parameters and for the vector ``theta``.

    Parameters
    ----------
    theta : flat parameter vector; slices ``[:18]``, ``[18:36]``, ``[36:69]``
        and ``[69:]`` are written to bond ``k``, bond ``eq``, angle ``k`` and
        angle ``eq`` of a copy of ``g_``.
    g_ : heterograph carrying the reference parameters (not modified; a deep
        copy receives ``theta``).
    x : tensor whose first dimension counts the conformations; it is reshaped
        to rows of three coordinates and stored as the atoms' ``xyz`` data.

    Returns
    -------
    (u, u_hat, norm(g_), norm(g)) : the per-molecule sums of the ``ubond`` and
        ``uangle`` terms for the reference and the ``theta`` graphs, followed
        by both batched graphs passed through the module-level ``norm``.
    """
    parameter_layout = (
        ('bond', 'k', slice(None, 18)),
        ('bond', 'eq', slice(18, 36)),
        ('angle', 'k', slice(36, 69)),
        ('angle', 'eq', slice(69, None)),
    )

    def summed_energy(graph, terms):
        # One column per energy term, then reduce across the columns.
        columns = [
            graph.nodes['mol'].data['u' + term][:, None] for term in terms
        ]
        return torch.sum(torch.cat(columns, dim=1), dim=1)

    g = copy.deepcopy(g_)
    for term, param, span in parameter_layout:
        g.nodes[term].data[param] = theta[span]
    g = unnorm(g)

    g_ = dgl.batch_hetero([g_ for _ in range(x.shape[0])])
    g = dgl.batch_hetero([g for _ in range(x.shape[0])])

    g.nodes['atom'].data['xyz'] = torch.reshape(x, [-1, 3])
    g_.nodes['atom'].data['xyz'] = torch.reshape(x, [-1, 3])

    g = hgfp.mm.geometry_in_heterograph.from_heterograph_with_xyz(g)
    g_ = hgfp.mm.geometry_in_heterograph.from_heterograph_with_xyz(g_)

    g = hgfp.mm.energy_in_heterograph.u(g)
    g_ = hgfp.mm.energy_in_heterograph.u(g_)

    # Only the bond and angle terms are active; the torsion, one_four,
    # nonbonded and '0' terms were commented out in both sums.
    active_terms = ['bond', 'angle']
    u = summed_energy(g_, active_terms)
    u_hat = summed_energy(g, active_terms)

    return u, u_hat, norm(g_), norm(g)
def batch(graphs):
    """Batch a uniform collection of graphs into a single DGL graph.

    Parameters
    ----------
    graphs : iterable of graphs, all of the same kind.  Supported kinds are
        ``esp.graphs.graph.Graph`` (its ``heterograph`` attribute is batched),
        ``dgl.DGLGraph`` and ``dgl.DGLHeteroGraph``.

    Returns
    -------
    The result of ``dgl.batch_hetero`` (for ``Graph`` and ``DGLHeteroGraph``
    inputs) or ``dgl.batch`` (for ``DGLGraph`` inputs).

    Raises
    ------
    RuntimeError
        If the elements are not all of one supported kind.
    """
    import dgl

    # Materialize once: the input is scanned several times below, which would
    # silently exhaust a generator after the first check.
    graphs = list(graphs)

    if all(isinstance(graph, esp.graphs.graph.Graph) for graph in graphs):
        return dgl.batch_hetero([graph.heterograph for graph in graphs])

    if all(isinstance(graph, dgl.DGLGraph) for graph in graphs):
        return dgl.batch(graphs)

    if all(isinstance(graph, dgl.DGLHeteroGraph) for graph in graphs):
        return dgl.batch_hetero(graphs)

    # Report every distinct type, not just the first element's: a mixed list
    # is the usual way to end up here and graphs[0] may itself be valid.
    found = ", ".join(sorted({str(type(graph)) for graph in graphs}))
    raise RuntimeError(
        "Can only batch DGLGraph or DGLHeterograph, now have %s" % found)
def test_acnn():
    """Run ACNN on a single and on a batched protein-ligand complex graph.

    The example molecules are downloaded into scratch directories ``tmp1`` /
    ``tmp2``, loaded, and the directories are removed again.  The model is
    then evaluated with its default configuration and with an explicit one;
    in both cases the output must have one row per graph in the input.
    """
    remove_dir('tmp1')
    remove_dir('tmp2')
    try:
        url = _get_dgl_url('dgllife/example_mols.tar.gz')
        local_path = 'tmp1/example_mols.tar.gz'
        download(url, path=local_path)
        extract_archive(local_path, 'tmp2')

        pocket_mol, pocket_coords = load_molecule(
            'tmp2/example_mols/example.pdb', remove_hs=True)
        ligand_mol, ligand_coords = load_molecule(
            'tmp2/example_mols/example.pdbqt', remove_hs=True)
    finally:
        # Clean up even when the download, extraction or parsing fails, so a
        # failed run does not leave scratch directories behind.
        remove_dir('tmp1')
        remove_dir('tmp2')

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    g1 = ACNN_graph_construction_and_featurization(ligand_mol,
                                                   pocket_mol,
                                                   ligand_coords,
                                                   pocket_coords)

    def check_output_shapes(model):
        # NOTE(review): the results of the `.to(device)` calls on the graphs
        # are discarded, so this relies on `to` moving the graph in place.
        # Confirm against the DGL version in use; if `to` returns a moved
        # copy there, these need to become `g = g.to(device)`.
        model.to(device)
        g1.to(device)
        assert model(g1).shape == torch.Size([1, 1])

        bg = dgl.batch_hetero([g1, g1])
        bg.to(device)
        assert model(bg).shape == torch.Size([2, 1])

    # Default configuration.
    check_output_shapes(ACNN())

    # Explicit configuration.
    check_output_shapes(ACNN(hidden_sizes=[1, 2],
                             weight_init_stddevs=[1, 1],
                             dropouts=[0.1, 0.],
                             features_to_use=torch.tensor([6., 8.]),
                             radial=[[12.0], [0.0, 2.0], [4.0]]))
def test_pickling_batched_heterograph():
    """Pickle round-trip of a batch built from relation graphs.

    Two heterographs are assembled from the same four relation graphs (one
    from an edge list, one from a SciPy sparse matrix, one from a NetworkX
    graph), given random features, batched, and the batch is compared with
    its reconstructed pickle.
    """
    # copied from test_heterograph.create_test_heterograph()
    plays_spmat = ssp.coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 1], [0, 0, 1, 1])))

    wishes_nx = nx.DiGraph()
    wishes_nx.add_nodes_from(['u0', 'u1', 'u2'], bipartite=0)
    wishes_nx.add_nodes_from(['g0', 'g1'], bipartite=1)
    wishes_nx.add_edge('u0', 'g1', id=0)
    wishes_nx.add_edge('u2', 'g0', id=1)

    relations = [
        dgl.graph([(0, 1), (1, 2)], 'user', 'follows'),
        dgl.bipartite(plays_spmat, 'user', 'plays', 'game'),
        dgl.bipartite(wishes_nx, 'user', 'wishes', 'game'),
        dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game'),
    ]

    g = dgl.hetero_from_relations(list(relations))
    g2 = dgl.hetero_from_relations(list(relations))

    # Same feature names and shapes on both graphs, independent random values.
    for graph in (g, g2):
        graph.nodes['user'].data['u_h'] = F.randn((3, 4))
        graph.nodes['game'].data['g_h'] = F.randn((2, 5))
        graph.edges['plays'].data['p_h'] = F.randn((4, 6))

    bg = dgl.batch_hetero([g, g2])
    new_bg = _reconstruct_pickle(bg)
    test_utils.check_graph_equal(bg, new_bg)
def test_pickling_batched_heterograph():
    """Pickle round-trip of a batch built from two identical-topology graphs.

    Both heterographs are created from the same edge dictionary layout, given
    random features, batched, and the batch is compared with its
    reconstructed pickle.
    """
    # copied from test_heterograph.create_test_heterograph()
    def build():
        return dgl.heterograph({
            ('user', 'follows', 'user'): ([0, 1], [1, 2]),
            ('user', 'plays', 'game'): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ('user', 'wishes', 'game'): ([0, 2], [1, 0]),
            ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
        })

    g = build()
    g2 = build()

    # Same feature names and shapes on both graphs, independent random values.
    for graph in (g, g2):
        graph.nodes['user'].data['u_h'] = F.randn((3, 4))
        graph.nodes['game'].data['g_h'] = F.randn((2, 5))
        graph.edges['plays'].data['p_h'] = F.randn((4, 6))

    bg = dgl.batch_hetero([g, g2])
    new_bg = _reconstruct_pickle(bg)
    test_utils.check_graph_equal(bg, new_bg)
def test_batching_hetero_topology(index_dtype):
    """Batch two heterographs where some nodes are isolated in some relations.

    Checks the schema, per-member node/edge counts, consecutive relabeling of
    nodes and edges, and that unbatching yields graphs equivalent to the
    inputs.
    """
    user_user = ('user', 'follows', 'user')
    user_dev = ('user', 'follows', 'developer')
    user_game = ('user', 'plays', 'game')

    def build(play_edges):
        return dgl.heterograph({
            user_user: [(0, 1), (1, 2)],
            user_dev: [(0, 1), (1, 2)],
            user_game: play_edges,
        }, index_dtype=index_dtype)

    g1 = build([(0, 0), (1, 0), (2, 1), (3, 1)])
    g2 = build([(0, 0), (1, 0), (2, 1)])
    bg = dgl.batch_hetero([g1, g2])
    members = (g1, g2)

    assert bg.ntypes == g2.ntypes
    assert bg.etypes == g2.etypes
    assert bg.canonical_etypes == g2.canonical_etypes
    assert bg.batch_size == 2

    # Test number of nodes
    for ntype in bg.ntypes:
        per_graph = [g.number_of_nodes(ntype) for g in members]
        assert bg.batch_num_nodes(ntype) == per_graph
        assert bg.number_of_nodes(ntype) == per_graph[0] + per_graph[1]

    # Test number of edges: 'plays' by its short name, then every relation by
    # its canonical name
    for etype in ['plays'] + list(bg.canonical_etypes):
        per_graph = [g.number_of_edges(etype) for g in members]
        assert bg.batch_num_edges(etype) == per_graph
        assert bg.number_of_edges(etype) == per_graph[0] + per_graph[1]

    # Test relabeled nodes
    for ntype in bg.ntypes:
        node_ids = list(F.asnumpy(bg.nodes(ntype)))
        assert node_ids == list(range(bg.number_of_nodes(ntype)))

    # Test relabeled edges
    expected_edges = (
        (user_user, [0, 1, 4, 5], [1, 2, 5, 6]),
        (user_dev, [0, 1, 4, 5], [1, 2, 4, 5]),
        ('plays', [0, 1, 2, 3, 4, 5, 6], [0, 0, 1, 1, 2, 2, 3]),
    )
    for etype, expected_src, expected_dst in expected_edges:
        src, dst = bg.all_edges(etype=etype)
        assert list(F.asnumpy(src)) == expected_src
        assert list(F.asnumpy(dst)) == expected_dst

    # Test unbatching graphs
    g3, g4 = dgl.unbatch_hetero(bg)
    check_equivalence_between_heterographs(g1, g3)
    check_equivalence_between_heterographs(g2, g4)
def test_batched_features(index_dtype):
    """Check which features are carried into a batch and back out of it.

    All node features are requested (``ALL``); for edges only 'h1' of the
    user-follows-user relation is requested and none for user-plays-game.
    Batched features must equal the concatenation of the members' features,
    unrequested edge features must be absent, and unbatching must reproduce
    the inputs on the selected attributes.
    """
    follows = ('user', 'follows', 'user')
    plays = ('user', 'plays', 'game')

    def build():
        graph = dgl.heterograph(
            {follows: [(0, 1), (1, 2)], plays: [(0, 0), (1, 0)]},
            index_dtype=index_dtype)
        graph.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
        graph.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
        graph.nodes['game'].data['h1'] = F.tensor([[0.]])
        graph.nodes['game'].data['h2'] = F.tensor([[1.]])
        graph.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
        graph.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
        graph.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])
        return graph

    g1 = build()
    g2 = build()

    bg = dgl.batch_hetero(
        [g1, g2],
        node_attrs=ALL,
        edge_attrs={follows: 'h1', plays: None})

    # Every node feature is the member-wise concatenation.
    for ntype in ('user', 'game'):
        for key in ('h1', 'h2'):
            expected = F.cat(
                [g1.nodes[ntype].data[key], g2.nodes[ntype].data[key]],
                dim=0)
            assert F.allclose(bg.nodes[ntype].data[key], expected)

    expected = F.cat(
        [g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']],
        dim=0)
    assert F.allclose(bg.edges['follows'].data['h1'], expected)
    # Edge features that were not requested must not appear in the batch.
    assert 'h2' not in bg.edges['follows'].data.keys()
    assert 'h1' not in bg.edges['plays'].data.keys()

    # Test unbatching graphs
    g3, g4 = dgl.unbatch_hetero(bg)
    for original, recovered in ((g1, g3), (g2, g4)):
        check_equivalence_between_heterographs(
            original, recovered,
            node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
            edge_attrs={('user', 'follows', 'user'): ['h1']})
def collate(data):
    """Merge a list of dataset samples into one mini-batch.

    Parameters
    ----------
    data : sequence of 5-tuples ``(index, ligand_mol, protein_mol, graph,
        label)``.

    Returns
    -------
    (indices, ligand_mols, protein_mols, bg, labels) : the first three are
        lists, ``bg`` is the heterograph batch of all sample graphs with
        zero-initializers installed for every node type and canonical edge
        type, and ``labels`` is the labels stacked along a new first
        dimension.
    """
    columns = [list(column) for column in zip(*data)]
    indices, ligand_mols, protein_mols, graphs, labels = columns

    bg = dgl.batch_hetero(graphs)
    # Install zero-initializers for features of every node and edge type.
    for node_type in bg.ntypes:
        bg.set_n_initializer(dgl.init.zero_initializer, ntype=node_type)
    for edge_type in bg.canonical_etypes:
        bg.set_e_initializer(dgl.init.zero_initializer, etype=edge_type)

    stacked_labels = torch.stack(labels, dim=0)
    return indices, ligand_mols, protein_mols, bg, stacked_labels
def _collect_charges(net, eq, dataset):
    """Return (reference, predicted) atom charges of `dataset` as flat arrays."""
    # Seed with an empty float64 array so an empty dataset still concatenates
    # and the result dtype matches NumPy's default float.
    reference = [np.empty(0)]
    predicted = [np.empty(0)]
    for g in dataset:
        # Run the model before reading `q`, keeping the original order.
        q_hat = eq(net(g)).nodes['atom'].data['q_hat']
        q = g.nodes['atom'].data['q']
        reference.append(q.detach().numpy())
        predicted.append(q_hat.detach().numpy())
    return (np.concatenate(reference, axis=0),
            np.concatenate(predicted, axis=0))


def _save_plot(curves, out_path, title=None):
    """Plot `curves` (pairs of values and optional label) and save to `out_path`."""
    plt.figure()
    has_label = False
    for values, label in curves:
        if label is None:
            plt.plot(values)
        else:
            plt.plot(values, label=label)
            has_label = True
    if has_label:
        plt.legend()
    if title is not None:
        plt.title(title)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def train(path, config, batch_size=16, learning_rate=1e-5, n_epoches=50):
    """Train a charge-equilibration model and write diagnostic plots.

    Graphs are loaded from `path`, converted to heterographs whose atom data
    ``q`` is the sum of the ``am1_charge`` and ``bcc_charge`` node data,
    grouped into batches of `batch_size`, split into training / test /
    validation sets, and a ``gnn_charge.models.Net`` followed by a
    ``gnn_charge.eq.ChargeEquilibrium`` layer is fitted with Adam on the MSE
    between ``q`` and the predicted ``q_hat``.

    Plots of RMSE, R^2 and the training loss are saved as ``RMSE.jpg``,
    ``R2.jpg`` and ``loss.jpg`` in a new directory named after the start
    time.

    Parameters
    ----------
    path : location passed to ``dgl.data.utils.load_graphs``.
    config : configuration passed to ``gnn_charge.models.Net``.
    batch_size : number of graphs per batch; graphs left over after the last
        full batch are dropped.
    learning_rate : Adam learning rate.
    n_epoches : number of passes over the training set.

    Raises
    ------
    ValueError
        If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1, got %s" % batch_size)

    time_str = strftime("%Y-%m-%d_%H_%M_%S", localtime())
    os.mkdir(time_str)

    gs_, _ = dgl.data.utils.load_graphs(path)

    gs = []
    for g_ in gs_:
        g = hgfp.heterograph.from_graph(g_)
        g.nodes['atom'].data['q'] = (
            g_.ndata['am1_charge'] + g_.ndata['bcc_charge'])
        gs.append(g)

    # Build full batches only.  This keeps the original outcome (a trailing
    # incomplete batch is dropped) without using a bare `except` as loop
    # control, which also hid genuine errors raised while batching.
    n_full_batches = len(gs) // batch_size
    gs_batched = [
        dgl.batch_hetero(gs[start:start + batch_size])
        for start in range(0, n_full_batches * batch_size, batch_size)
    ]

    ds_tr, ds_te, ds_vl = hgfp.data.utils.split(gs_batched, 1, 1)

    net = gnn_charge.models.Net(config)
    eq = gnn_charge.eq.ChargeEquilibrium()

    opt = torch.optim.Adam(
        list(net.parameters()) + list(eq.parameters()),
        learning_rate)

    loss_fn = torch.nn.functional.mse_loss

    # The leading 0. mirrors the original placeholder entry; the loss plot
    # skips the first ten entries anyway.
    losses = [0.]
    rmse_vl = []
    r2_vl = []
    rmse_tr = []
    r2_tr = []

    for idx_epoch in range(n_epoches):
        # `net.eval()` is called below for the metrics; switch back so every
        # epoch, not just the first, trains in training mode.
        net.train()
        for g in ds_tr:
            g_hat = eq(net(g))
            q = g.nodes['atom'].data['q']
            q_hat = g_hat.nodes['atom'].data['q_hat']

            # Prediction first, following mse_loss's (input, target) order;
            # the loss value is unchanged because MSE is symmetric.
            loss = loss_fn(q_hat, q)
            # Record the loss so that loss.jpg actually shows a curve.
            losses.append(loss.item())

            opt.zero_grad()
            loss.backward()
            opt.step()

        net.eval()
        with torch.no_grad():
            u_tr, u_hat_tr = _collect_charges(net, eq, ds_tr)
            u_vl, u_hat_vl = _collect_charges(net, eq, ds_vl)

        rmse_tr.append(np.sqrt(mean_squared_error(u_tr, u_hat_tr)))
        rmse_vl.append(np.sqrt(mean_squared_error(u_vl, u_hat_vl)))
        r2_tr.append(r2_score(u_tr, u_hat_tr))
        r2_vl.append(r2_score(u_vl, u_hat_vl))

        # NOTE(review): the flattened source does not show whether the plots
        # were refreshed every epoch or written once after training.  They
        # are refreshed every epoch here; the final files are the same either
        # way.  Confirm against the original layout.
        plt.style.use('fivethirtyeight')
        _save_plot(
            [(rmse_tr[1:], r'$RMSE_\mathtt{TRAIN}$'),
             (rmse_vl[1:], r'$RMSE_\mathtt{VALIDATION}$')],
            time_str + '/RMSE.jpg')
        _save_plot(
            [(r2_tr[1:], r'$R^2_\mathtt{TRAIN}$'),
             (r2_vl[1:], r'$R^2_\mathtt{VALIDATION}$')],
            time_str + '/R2.jpg')
        _save_plot([(losses[10:], None)], time_str + '/loss.jpg',
                   title='loss')
# net.load_state_dict(torch.load('/data/chodera/wangyq/hgfp_scripts/gcn_param/2020-03-30_11_41_19/model')) net.load_state_dict(torch.load('model_multi.ds')) mean_and_std_dict = torch.load('/data/chodera/wangyq/hgfp_scripts/gcn_param/2020-03-30_11_41_19/norm_dict') norm, unnorm = hgfp.data.utils.get_norm_fn(mean_and_std_dict) loss_fn = torch.nn.functional.mse_loss for g_, u in ds: g_ = dgl.unbatch_hetero(g_) idxs = random.choices(list(range(len(g_))), k=64) g_ = dgl.batch_hetero( [g_[idx] for idx in idxs]) g = copy.deepcopy(g_) u = u[idxs] g = net(g, return_graph=True) g = hgfp.mm.geometry_in_heterograph.from_heterograph_with_xyz( g) g = hgfp.mm.energy_in_heterograph.u(g) g = unnorm(g) u = torch.sum(
def forward(self, **params):
    '''
    Score head/tail entity pairs of a batch of documents.

    The visible steps are: embed and encode the tokens, average token
    encodings into per-mention node features, run the GCN layers over the
    batched mention graphs, average mention features into entity features,
    run the edge layer over the batched entity graphs, attend over two-hop
    paths between each head/tail pair, and feed the combined features to
    ``self.predict``.

    Keyword arguments read from ``params`` (shapes as given by the original
    author):
        words: [batch_size, max_length]
        src_lengths: [batchs_size]
        mask: [batch_size, max_length]
        entity_type: [batch_size, max_length]
        entity_id: [batch_size, max_length]
        mention_id: [batch_size, max_length]
        distance: [batch_size, max_length]
        entity2mention_table: list of [local_entity_num, local_mention_num]
        graphs: list of DGLHeteroGraph
        h_t_pairs: [batch_size, h_t_limit, 2]
    The body additionally reads 'entity_graphs', 'relation_mask' and
    'path_table'; 'distance' is listed above but not read in this method.

    Returns the output of ``self.predict`` for every head/tail pair.
    '''
    src = self.word_emb(params['words'])
    mask = params['mask']
    bsz, slen, _ = src.size()

    # Optionally append entity-type / entity-id embeddings to each token.
    if self.config.use_entity_type:
        src = torch.cat(
            [src, self.entity_type_emb(params['entity_type'])], dim=-1)

    if self.config.use_entity_id:
        src = torch.cat([src, self.entity_id_emb(params['entity_id'])],
                        dim=-1)

    # src: [batch_size, slen, encoder_input_size]
    # src_lengths: [batchs_size]

    encoder_outputs, (output_h_t, _) = self.encoder(src,
                                                    params['src_lengths'])
    # Zero the encoder output at positions where the mask is 0.
    encoder_outputs[mask == 0] = 0
    # encoder_outputs: [batch_size, slen, 2*encoder_hid_size]
    # output_h_t: [batch_size, 2*encoder_hid_size]

    graphs = params['graphs']

    mention_id = params['mention_id']
    features = None

    # Build the node features of every document and stack them in document
    # order: first output_h_t[i], then one row per mention id 1..mention_num.
    for i in range(len(graphs)):
        encoder_output = encoder_outputs[i]  # [slen, 2*encoder_hid_size]
        mention_num = torch.max(mention_id[i])
        mention_index = get_cuda(
            (torch.arange(mention_num) + 1).unsqueeze(1).expand(
                -1, slen))  # [mention_num, slen]
        mentions = mention_id[i].unsqueeze(0).expand(
            mention_num, -1)  # [mention_num, slen]
        # Row m is 1.0 where the token's mention id equals m + 1.
        select_metrix = (
            mention_index == mentions).float()  # [mention_num, slen]
        # average word -> mention
        word_total_numbers = torch.sum(select_metrix,
                                       dim=-1).unsqueeze(-1).expand(
                                           -1, slen)  # [mention_num, slen]
        # Normalize each row by its token count; rows with no token are left
        # as they are (all zeros) to avoid dividing by zero.
        select_metrix = torch.where(word_total_numbers > 0,
                                    select_metrix / word_total_numbers,
                                    select_metrix)
        x = torch.mm(select_metrix,
                     encoder_output)  # [mention_num, 2*encoder_hid_size]

        x = torch.cat((output_h_t[i].unsqueeze(0), x), dim=0)

        if features is None:
            features = x
        else:
            features = torch.cat((features, x), dim=0)

    # Run every GCN layer on the batched graph, keeping the input features
    # and each layer's output so they can be concatenated feature-wise.
    graph_big = dgl.batch_hetero(graphs)
    output_features = [features]

    for GCN_layer in self.GCN_layers:
        features = GCN_layer(
            graph_big,
            {"node": features})["node"]  # [total_mention_nums, gcn_dim]
        output_features.append(features)

    output_feature = torch.cat(output_features, dim=-1)

    graphs = dgl.unbatch_hetero(graph_big)

    # mention -> entity
    entity2mention_table = params[
        'entity2mention_table']  # list of [entity_num, mention_num]
    entity_num = torch.max(params['entity_id'])
    # NOTE(review): torch.Tensor(sizes...) allocates without initializing;
    # entity_bank rows that the loop below does not overwrite presumably keep
    # arbitrary values -- confirm they are never gathered unmasked.
    entity_bank = get_cuda(torch.Tensor(bsz, entity_num, self.bank_size))
    global_info = get_cuda(torch.Tensor(bsz, self.bank_size))

    # cur_idx is the offset of the current document's first node row inside
    # the stacked output_feature matrix.
    cur_idx = 0
    entity_graph_feature = None
    for i in range(len(graphs)):
        # average mention -> entity
        select_metrix = entity2mention_table[i].float(
        )  # [local_entity_num, mention_num]
        # NOTE(review): if the table is already a float tensor, .float() may
        # hand back the same tensor, in which case this write also changes
        # the caller's table -- confirm this is acceptable.
        select_metrix[0][0] = 1
        mention_nums = torch.sum(select_metrix,
                                 dim=-1).unsqueeze(-1).expand(
                                     -1, select_metrix.size(1))
        select_metrix = torch.where(mention_nums > 0,
                                    select_metrix / mention_nums,
                                    select_metrix)
        node_num = graphs[i].number_of_nodes('node')
        entity_representation = torch.mm(
            select_metrix, output_feature[cur_idx:cur_idx + node_num])
        # Row 0 of entity_representation is skipped when filling the bank.
        entity_bank[i, :select_metrix.size(0) -
                    1] = entity_representation[1:]
        # The document's first node row serves as its global feature.
        global_info[i] = output_feature[cur_idx]
        cur_idx += node_num

        # Collect the trailing gcn_dim feature columns of the entity rows of
        # every document for the entity-graph edge layer.
        if entity_graph_feature is None:
            entity_graph_feature = entity_representation[
                1:, -self.config.gcn_dim:]
        else:
            entity_graph_feature = torch.cat(
                (entity_graph_feature,
                 entity_representation[1:, -self.config.gcn_dim:]),
                dim=0)

    h_t_pairs = params['h_t_pairs']
    # Shift ids down by one; entries equal to 0 stay 0 (0 + 1 - 1).
    h_t_pairs = h_t_pairs + (h_t_pairs == 0).long() - 1
    # [batch_size, h_t_limit, 2]
    h_t_limit = h_t_pairs.size(1)

    # [batch_size, h_t_limit, bank_size]
    h_entity_index = h_t_pairs[:, :, 0].unsqueeze(-1).expand(
        -1, -1, self.bank_size)
    t_entity_index = h_t_pairs[:, :, 1].unsqueeze(-1).expand(
        -1, -1, self.bank_size)

    # [batch_size, h_t_limit, bank_size]
    h_entity = torch.gather(input=entity_bank, dim=1, index=h_entity_index)
    t_entity = torch.gather(input=entity_bank, dim=1, index=t_entity_index)

    global_info = global_info.unsqueeze(1).expand(-1, h_t_limit, -1)

    # The return value of edge_layer is ignored; the loop below reads
    # edata['h'] from the unbatched graphs, so the layer presumably stores
    # its result there -- TODO confirm.
    entity_graphs = params['entity_graphs']
    entity_graph_big = dgl.batch(entity_graphs)
    self.edge_layer(entity_graph_big, entity_graph_feature)
    entity_graphs = dgl.unbatch(entity_graph_big)
    path_info = get_cuda(torch.zeros((bsz, h_t_limit, self.gcn_dim * 4)))
    relation_mask = params['relation_mask']
    path_table = params['path_table']
    for i in range(len(entity_graphs)):
        path_t = path_table[i]
        for j in range(h_t_limit):
            # With a mask, the first masked-out pair ends this document.
            if relation_mask is not None and relation_mask[i,
                                                           j].item() == 0:
                break

            h = h_t_pairs[i, j, 0].item()
            t = h_t_pairs[i, j, 1].item()
            # for evaluate
            if relation_mask is None and h == 0 and t == 0:
                continue

            # path_t is looked up with the ids shifted back up by one, in
            # either order; its values are shifted down by one into `v`.
            if (h + 1, t + 1) in path_t:
                v = [val - 1 for val in path_t[(h + 1, t + 1)]]
            elif (t + 1, h + 1) in path_t:
                v = [val - 1 for val in path_t[(t + 1, h + 1)]]
            else:
                # Diagnostic dump followed by a deliberate assertion failure.
                # NOTE(review): `v` is unbound if this branch is hit before
                # any pair matched (NameError instead of the dump), and
                # stale otherwise.
                print(h, t, v)
                print(entity_graphs[i].all_edges())
                print(h_t_pairs)
                print(relation_mask)
                assert 1 == 2

            middle_node_num = len(v)

            if middle_node_num == 0:
                continue

            # forward
            edge_ids = get_cuda(entity_graphs[i].edge_ids(
                [h for _ in range(middle_node_num)], v))
            forward_first = torch.index_select(entity_graphs[i].edata['h'],
                                               dim=0,
                                               index=edge_ids)
            edge_ids = get_cuda(entity_graphs[i].edge_ids(
                v, [t for _ in range(middle_node_num)]))
            forward_second = torch.index_select(
                entity_graphs[i].edata['h'], dim=0, index=edge_ids)

            # backward
            edge_ids = get_cuda(entity_graphs[i].edge_ids(
                [t for _ in range(middle_node_num)], v))
            backward_first = torch.index_select(
                entity_graphs[i].edata['h'], dim=0, index=edge_ids)
            edge_ids = get_cuda(entity_graphs[i].edge_ids(
                v, [h for _ in range(middle_node_num)]))
            backward_second = torch.index_select(
                entity_graphs[i].edata['h'], dim=0, index=edge_ids)

            # One row per middle node: the four edge features h->v, v->t,
            # t->v and v->h side by side.
            tmp_path_info = torch.cat((forward_first, forward_second,
                                       backward_first, backward_second),
                                      dim=-1)
            _, attn_value = self.attention(
                torch.cat((h_entity[i, j], t_entity[i, j]), dim=-1),
                tmp_path_info)
            path_info[i, j] = attn_value

        # Drop the edge feature once this document's pairs are processed.
        entity_graphs[i].edata.pop('h')

    path_info = self.dropout(
        self.activation(self.path_info_mapping(path_info)))

    predictions = self.predict(
        torch.cat((h_entity, t_entity, torch.abs(h_entity - t_entity),
                   torch.mul(h_entity, t_entity), global_info, path_info),
                  dim=-1))
    return predictions
def collate_movielens(data):
    """Merge ``(graph, label)`` samples into a batched graph and label tensor.

    Parameters
    ----------
    data : sequence of ``(graph, label)`` pairs.

    Returns
    -------
    (g, g_label) : the heterograph batch of all sample graphs and the labels
        stacked along a new first dimension.
    """
    graphs, labels = (list(column) for column in zip(*data))
    batched_graph = dgl.batch_hetero(graphs)
    stacked_labels = th.stack(labels)
    return batched_graph, stacked_labels