def test_batching_hetero_and_batched_hetero_topology(index_dtype):
    """Batch an already-batched heterograph together with a plain heterograph.

    Two identical graphs are batched first, the result is batched again with
    a third, smaller graph, and the topology of the final batch is checked
    against the three source graphs.
    """
    two_edge_spec = {
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): [(0, 0), (1, 0)]
    }
    g1 = dgl.heterograph(dict(two_edge_spec), index_dtype=index_dtype)
    g2 = dgl.heterograph(dict(two_edge_spec), index_dtype=index_dtype)
    bg1 = dgl.batch_hetero([g1, g2])

    g3 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1)],
        ('user', 'plays', 'game'): [(1, 0)]
    }, index_dtype=index_dtype)
    bg2 = dgl.batch_hetero([bg1, g3])
    members = [g1, g2, g3]

    # The batch exposes the same schema as its members.
    assert bg2.ntypes == g3.ntypes
    assert bg2.etypes == g3.etypes
    assert bg2.canonical_etypes == g3.canonical_etypes
    assert bg2.batch_size == 3

    # Node counts, per member and in total.
    for ntype in bg2.ntypes:
        per_graph = [g.number_of_nodes(ntype) for g in members]
        assert bg2.batch_num_nodes(ntype) == per_graph
        assert bg2.number_of_nodes(ntype) == sum(per_graph)

    # Edge counts, addressed first by short name and then by canonical triple.
    for etype in [*bg2.etypes, *bg2.canonical_etypes]:
        per_graph = [g.number_of_edges(etype) for g in members]
        assert bg2.batch_num_edges(etype) == per_graph
        assert bg2.number_of_edges(etype) == sum(per_graph)

    # Node IDs are relabeled to a contiguous range.
    for ntype in bg2.ntypes:
        expected_ids = list(range(bg2.number_of_nodes(ntype)))
        assert list(F.asnumpy(bg2.nodes(ntype))) == expected_ids

    # Edge endpoints are shifted by the node offsets of the preceding graphs.
    follows_src, follows_dst = bg2.all_edges(etype='follows')
    assert list(F.asnumpy(follows_src)) == [0, 1, 3, 4, 6]
    assert list(F.asnumpy(follows_dst)) == [1, 2, 4, 5, 7]
    plays_src, plays_dst = bg2.all_edges(etype='plays')
    assert list(F.asnumpy(plays_src)) == [0, 1, 3, 4, 7]
    assert list(F.asnumpy(plays_dst)) == [0, 0, 1, 1, 2]

    # Unbatching yields the three source graphs again.
    g4, g5, g6 = dgl.unbatch_hetero(bg2)
    for original, recovered in zip(members, (g4, g5, g6)):
        check_equivalence_between_heterographs(original, recovered)
def test_batching_hetero_topology(index_dtype):
    """Batch two heterographs in which some nodes are isolated in some relations."""
    follows_user = ('user', 'follows', 'user')
    follows_dev = ('user', 'follows', 'developer')
    plays_game = ('user', 'plays', 'game')

    g1 = dgl.heterograph({
        follows_user: [(0, 1), (1, 2)],
        follows_dev: [(0, 1), (1, 2)],
        plays_game: [(0, 0), (1, 0), (2, 1), (3, 1)]
    }, index_dtype=index_dtype)
    g2 = dgl.heterograph({
        follows_user: [(0, 1), (1, 2)],
        follows_dev: [(0, 1), (1, 2)],
        plays_game: [(0, 0), (1, 0), (2, 1)]
    }, index_dtype=index_dtype)
    members = [g1, g2]
    bg = dgl.batch_hetero(members)

    # The batch exposes the same schema as its members.
    assert bg.ntypes == g2.ntypes
    assert bg.etypes == g2.etypes
    assert bg.canonical_etypes == g2.canonical_etypes
    assert bg.batch_size == 2

    # Node counts, per member and in total.
    for ntype in bg.ntypes:
        per_graph = [g.number_of_nodes(ntype) for g in members]
        assert bg.batch_num_nodes(ntype) == per_graph
        assert bg.number_of_nodes(ntype) == sum(per_graph)

    # Edge counts: 'plays' is the only short name checked on its own; every
    # relation is then checked through its canonical triple.
    for etype in ['plays', *bg.canonical_etypes]:
        per_graph = [g.number_of_edges(etype) for g in members]
        assert bg.batch_num_edges(etype) == per_graph
        assert bg.number_of_edges(etype) == sum(per_graph)

    # Node IDs are relabeled to a contiguous range.
    for ntype in bg.ntypes:
        expected_ids = list(range(bg.number_of_nodes(ntype)))
        assert list(F.asnumpy(bg.nodes(ntype))) == expected_ids

    # Edge endpoints after relabeling, one relation at a time.
    expected_edges = [
        (follows_user, [0, 1, 4, 5], [1, 2, 5, 6]),
        (follows_dev, [0, 1, 4, 5], [1, 2, 4, 5]),
        ('plays', [0, 1, 2, 3, 4, 5, 6], [0, 0, 1, 1, 2, 2, 3]),
    ]
    for etype, want_src, want_dst in expected_edges:
        src, dst = bg.all_edges(etype=etype)
        assert list(F.asnumpy(src)) == want_src
        assert list(F.asnumpy(dst)) == want_dst

    # Unbatching yields the two source graphs again.
    g3, g4 = dgl.unbatch_hetero(bg)
    check_equivalence_between_heterographs(g1, g3)
    check_equivalence_between_heterographs(g2, g4)
def test_batched_features(index_dtype):
    """Check which node/edge features survive batching and unbatching.

    All node features are requested (``node_attrs=ALL``); for edges only
    'h1' of the 'follows' relation is requested and nothing for 'plays'.
    """
    def build_graph():
        # Both inputs share the same topology and the same feature values.
        graph = dgl.heterograph({
            ('user', 'follows', 'user'): [(0, 1), (1, 2)],
            ('user', 'plays', 'game'): [(0, 0), (1, 0)]
        }, index_dtype=index_dtype)
        graph.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
        graph.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
        graph.nodes['game'].data['h1'] = F.tensor([[0.]])
        graph.nodes['game'].data['h2'] = F.tensor([[1.]])
        graph.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
        graph.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
        graph.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])
        return graph

    g1 = build_graph()
    g2 = build_graph()

    bg = dgl.batch_hetero(
        [g1, g2],
        node_attrs=ALL,
        edge_attrs={
            ('user', 'follows', 'user'): 'h1',
            ('user', 'plays', 'game'): None
        })

    # Every node feature is the row-wise concatenation of the inputs.
    for ntype, key in (('user', 'h1'), ('user', 'h2'),
                       ('game', 'h1'), ('game', 'h2')):
        stacked = F.cat([g1.nodes[ntype].data[key],
                         g2.nodes[ntype].data[key]], dim=0)
        assert F.allclose(bg.nodes[ntype].data[key], stacked)

    # Only the requested edge feature is carried over.
    stacked = F.cat([g1.edges['follows'].data['h1'],
                     g2.edges['follows'].data['h1']], dim=0)
    assert F.allclose(bg.edges['follows'].data['h1'], stacked)
    assert 'h2' not in bg.edges['follows'].data.keys()
    assert 'h1' not in bg.edges['plays'].data.keys()

    # Unbatching restores the graphs together with the batched features.
    g3, g4 = dgl.unbatch_hetero(bg)
    for original, recovered in ((g1, g3), (g2, g4)):
        check_equivalence_between_heterographs(
            original, recovered,
            node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
            edge_attrs={('user', 'follows', 'user'): ['h1']})
def test_batching_with_zero_nodes_edges(index_dtype):
    """Batch graphs where a relation, or the whole graph, has no edges.

    The first graph is built with an empty 'plays' edge list and carries no
    'game' or 'plays' features, so the batched 'game'/'plays' features are
    expected to equal those of the second graph alone.
    """
    g1 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): []
    }, index_dtype=index_dtype)
    g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g1.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g1.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])

    g2 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): [(0, 0), (1, 0)]
    }, index_dtype=index_dtype)
    g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g2.nodes['game'].data['h1'] = F.tensor([[0.]])
    g2.nodes['game'].data['h2'] = F.tensor([[1.]])
    g2.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g2.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
    g2.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])

    bg = dgl.batch_hetero([g1, g2])

    # 'user' features come from both graphs.
    for key in ('h1', 'h2'):
        stacked = F.cat([g1.nodes['user'].data[key],
                         g2.nodes['user'].data[key]], dim=0)
        assert F.allclose(bg.nodes['user'].data[key], stacked)

    # 'game' features can only come from g2.
    for key in ('h1', 'h2'):
        assert F.allclose(bg.nodes['game'].data[key],
                          g2.nodes['game'].data[key])

    stacked = F.cat([g1.edges['follows'].data['h1'],
                     g2.edges['follows'].data['h1']], dim=0)
    assert F.allclose(bg.edges['follows'].data['h1'], stacked)
    assert F.allclose(bg.edges['plays'].data['h1'],
                      g2.edges['plays'].data['h1'])

    # Unbatching restores both graphs.
    g3, g4 = dgl.unbatch_hetero(bg)
    for original, recovered in ((g1, g3), (g2, g4)):
        check_equivalence_between_heterographs(
            original, recovered,
            node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
            edge_attrs={('user', 'follows', 'user'): ['h1']})

    # Graphs without any edges: batching only has to complete without raising.
    g1 = dgl.bipartite([], 'u', 'r', 'v', num_nodes=(0, 4))
    g2 = dgl.bipartite([], 'u', 'r', 'v', num_nodes=(1, 5))
    g2.nodes['u'].data['x'] = F.tensor([1])
    dgl.batch_hetero([g1, g2])
# The same three-item spec repeated three times; presumably
# (units, dropout, activation) per layer -- confirm against Net.
net = hgfp.models.gcn_with_combine_readout.Net(
    ['128', '0.1', 'relu'] * 3)

# Alternative checkpoint that was previously loaded:
#   /data/chodera/wangyq/hgfp_scripts/gcn_param/2020-03-30_11_41_19/model
net.load_state_dict(torch.load('model_multi.ds'))

mean_and_std_dict = torch.load(
    '/data/chodera/wangyq/hgfp_scripts/gcn_param/2020-03-30_11_41_19/norm_dict')
norm, unnorm = hgfp.data.utils.get_norm_fn(mean_and_std_dict)
loss_fn = torch.nn.functional.mse_loss

for g_, u in ds:
    # Split the batch, draw 64 members with replacement, and re-batch them.
    g_ = dgl.unbatch_hetero(g_)
    idxs = random.choices(range(len(g_)), k=64)
    g_ = dgl.batch_hetero([g_[idx] for idx in idxs])

    # Work on a copy so the sampled batch itself stays untouched.
    g = copy.deepcopy(g_)
    u = u[idxs]

    g = net(g, return_graph=True)
    g = hgfp.mm.geometry_in_heterograph.from_heterograph_with_xyz(g)
def forward(self, **params):
    '''Predict relations for the head/tail entity pairs of a batch.

    Keyword arguments read from ``params`` (shapes as stated by the
    original author):
        words: [batch_size, max_length]
        src_lengths: [batch_size]
        mask: [batch_size, max_length]
        entity_type: [batch_size, max_length]
        entity_id: [batch_size, max_length]
        mention_id: [batch_size, max_length]
        entity2mention_table: list of [local_entity_num, local_mention_num]
        graphs: list of DGLHeteroGraph
        h_t_pairs: [batch_size, h_t_limit, 2]
        entity_graphs: list of graphs accepted by ``dgl.batch``
        relation_mask: indexed as ``relation_mask[i, j]``, or None
        path_table: per-sample mapping from a 1-based ``(h, t)`` pair to a
            list of 1-based intermediate entity ids

    ``distance`` ([batch_size, max_length]) was listed by the original
    docstring but is not read here.

    Returns:
        The output of ``self.predict`` on the concatenated pair features.

    Raises:
        AssertionError: if an unmasked pair has no entry (in either
            direction) in its sample's ``path_table``.
    '''
    src = self.word_emb(params['words'])
    mask = params['mask']
    bsz, slen, _ = src.size()

    if self.config.use_entity_type:
        src = torch.cat(
            [src, self.entity_type_emb(params['entity_type'])], dim=-1)
    if self.config.use_entity_id:
        src = torch.cat(
            [src, self.entity_id_emb(params['entity_id'])], dim=-1)

    # src: [batch_size, slen, encoder_input_size]
    # src_lengths: [batch_size]
    encoder_outputs, (output_h_t, _) = self.encoder(src,
                                                    params['src_lengths'])
    encoder_outputs[mask == 0] = 0
    # encoder_outputs: [batch_size, slen, 2*encoder_hid_size]
    # output_h_t: [batch_size, 2*encoder_hid_size]

    graphs = params['graphs']
    mention_id = params['mention_id']

    # Per-sample node features, collected and concatenated once after the
    # loop instead of growing one tensor with torch.cat on every iteration.
    feature_chunks = []
    for i in range(len(graphs)):
        encoder_output = encoder_outputs[i]  # [slen, 2*encoder_hid_size]
        mention_num = torch.max(mention_id[i])
        mention_index = get_cuda(
            (torch.arange(mention_num) + 1).unsqueeze(1).expand(
                -1, slen))  # [mention_num, slen]
        mentions = mention_id[i].unsqueeze(0).expand(
            mention_num, -1)  # [mention_num, slen]
        select_matrix = (
            mention_index == mentions).float()  # [mention_num, slen]

        # average word -> mention
        word_total_numbers = torch.sum(
            select_matrix, dim=-1).unsqueeze(-1).expand(
                -1, slen)  # [mention_num, slen]
        select_matrix = torch.where(word_total_numbers > 0,
                                    select_matrix / word_total_numbers,
                                    select_matrix)
        x = torch.mm(select_matrix,
                     encoder_output)  # [mention_num, 2*encoder_hid_size]

        # Row 0 of every sample is the encoder's final state.
        x = torch.cat((output_h_t[i].unsqueeze(0), x), dim=0)
        feature_chunks.append(x)
    features = torch.cat(feature_chunks, dim=0) if feature_chunks else None

    graph_big = dgl.batch_hetero(graphs)
    output_features = [features]
    for GCN_layer in self.GCN_layers:
        features = GCN_layer(
            graph_big,
            {"node": features})["node"]  # [total_mention_nums, gcn_dim]
        output_features.append(features)
    output_feature = torch.cat(output_features, dim=-1)

    graphs = dgl.unbatch_hetero(graph_big)

    # mention -> entity
    entity2mention_table = params[
        'entity2mention_table']  # list of [entity_num, mention_num]
    entity_num = torch.max(params['entity_id'])
    # NOTE(review): torch.Tensor(...) allocates without initialising; rows
    # that are never written below keep arbitrary values -- confirm that
    # downstream masking makes this harmless.
    entity_bank = get_cuda(torch.Tensor(bsz, entity_num, self.bank_size))
    global_info = get_cuda(torch.Tensor(bsz, self.bank_size))

    cur_idx = 0
    entity_graph_chunks = []
    for i in range(len(graphs)):
        # average mention -> entity
        select_matrix = entity2mention_table[i].float(
        )  # [local_entity_num, mention_num]
        select_matrix[0][0] = 1
        mention_nums = torch.sum(select_matrix,
                                 dim=-1).unsqueeze(-1).expand(
                                     -1, select_matrix.size(1))
        select_matrix = torch.where(mention_nums > 0,
                                    select_matrix / mention_nums,
                                    select_matrix)
        node_num = graphs[i].number_of_nodes('node')
        entity_representation = torch.mm(
            select_matrix, output_feature[cur_idx:cur_idx + node_num])
        # Row 0 is skipped when filling the entity bank; the sample's first
        # node feature is stored separately as its global information.
        entity_bank[i, :select_matrix.size(0) - 1] = \
            entity_representation[1:]
        global_info[i] = output_feature[cur_idx]
        cur_idx += node_num
        entity_graph_chunks.append(
            entity_representation[1:, -self.config.gcn_dim:])
    entity_graph_feature = (torch.cat(entity_graph_chunks, dim=0)
                            if entity_graph_chunks else None)

    h_t_pairs = params['h_t_pairs']
    # Shift ids down by one, leaving zero entries at zero.
    h_t_pairs = h_t_pairs + (h_t_pairs == 0).long() - 1
    # [batch_size, h_t_limit, 2]
    h_t_limit = h_t_pairs.size(1)

    # [batch_size, h_t_limit, bank_size]
    h_entity_index = h_t_pairs[:, :, 0].unsqueeze(-1).expand(
        -1, -1, self.bank_size)
    t_entity_index = h_t_pairs[:, :, 1].unsqueeze(-1).expand(
        -1, -1, self.bank_size)

    # [batch_size, h_t_limit, bank_size]
    h_entity = torch.gather(input=entity_bank, dim=1, index=h_entity_index)
    t_entity = torch.gather(input=entity_bank, dim=1, index=t_entity_index)

    global_info = global_info.unsqueeze(1).expand(-1, h_t_limit, -1)

    entity_graphs = params['entity_graphs']
    entity_graph_big = dgl.batch(entity_graphs)
    self.edge_layer(entity_graph_big, entity_graph_feature)
    entity_graphs = dgl.unbatch(entity_graph_big)
    path_info = get_cuda(torch.zeros((bsz, h_t_limit, self.gcn_dim * 4)))
    relation_mask = params['relation_mask']
    path_table = params['path_table']
    for i in range(len(entity_graphs)):
        path_t = path_table[i]
        edge_features = entity_graphs[i].edata['h']
        for j in range(h_t_limit):
            if relation_mask is not None and relation_mask[i, j].item() == 0:
                break

            h = h_t_pairs[i, j, 0].item()
            t = h_t_pairs[i, j, 1].item()
            # for evaluate
            if relation_mask is None and h == 0 and t == 0:
                continue

            if (h + 1, t + 1) in path_t:
                v = [val - 1 for val in path_t[(h + 1, t + 1)]]
            elif (t + 1, h + 1) in path_t:
                v = [val - 1 for val in path_t[(t + 1, h + 1)]]
            else:
                # No `v` exists on this branch, so it is not printed. The
                # failure is raised explicitly so that it also fires when
                # Python runs with -O (which strips assert statements).
                print(h, t)
                print(entity_graphs[i].all_edges())
                print(h_t_pairs)
                print(relation_mask)
                raise AssertionError(
                    "no path_table entry for pair ({}, {}) in sample {}"
                    .format(h + 1, t + 1, i))

            middle_node_num = len(v)
            if middle_node_num == 0:
                continue

            heads = [h] * middle_node_num
            tails = [t] * middle_node_num

            # forward: h -> v -> t
            edge_ids = get_cuda(entity_graphs[i].edge_ids(heads, v))
            forward_first = torch.index_select(edge_features, dim=0,
                                               index=edge_ids)
            edge_ids = get_cuda(entity_graphs[i].edge_ids(v, tails))
            forward_second = torch.index_select(edge_features, dim=0,
                                                index=edge_ids)

            # backward: t -> v -> h
            edge_ids = get_cuda(entity_graphs[i].edge_ids(tails, v))
            backward_first = torch.index_select(edge_features, dim=0,
                                                index=edge_ids)
            edge_ids = get_cuda(entity_graphs[i].edge_ids(v, heads))
            backward_second = torch.index_select(edge_features, dim=0,
                                                 index=edge_ids)

            tmp_path_info = torch.cat((forward_first, forward_second,
                                       backward_first, backward_second),
                                      dim=-1)
            _, attn_value = self.attention(
                torch.cat((h_entity[i, j], t_entity[i, j]), dim=-1),
                tmp_path_info)
            path_info[i, j] = attn_value

        # Drop the edge features once all pairs of this sample are handled.
        entity_graphs[i].edata.pop('h')

    path_info = self.dropout(
        self.activation(self.path_info_mapping(path_info)))

    predictions = self.predict(
        torch.cat((h_entity, t_entity, torch.abs(h_entity - t_entity),
                   torch.mul(h_entity, t_entity), global_info, path_info),
                  dim=-1))
    return predictions
import hgfp
import torch
import dgl
import numpy as np
import copy


def _zero_bond_parameter(batched_graph, field):
    """Return a zero-filled Parameter shaped like ``field`` on the first member's 'bond' nodes."""
    first_member = dgl.unbatch_hetero(batched_graph)[0]
    reference = first_member.nodes['bond'].data[field]
    return torch.nn.Parameter(torch.zeros_like(reference))


ds = hgfp.data.ani.df.topology_batched(0, mm=True)
loss_fn = torch.nn.functional.mse_loss

for g_, u in ds:
    # Previously the values were read from disk instead:
    #   k = torch.tensor(np.load('k.npy'))
    #   eq = torch.tensor(np.load('eq.npy'))
    k = _zero_bond_parameter(g_, 'k')
    eq = _zero_bond_parameter(g_, 'eq')

    opt = torch.optim.Adam([k, eq], lr=1e-1)
    for _ in range(100):
        # Fresh copy each step so the batch from the dataset is not modified.
        g = copy.deepcopy(g_)
        # Tile the single-graph parameters across every member of the batch.
        g.nodes['bond'].data['k'] = k.repeat(g.batch_size)
        g.nodes['bond'].data['eq'] = eq.repeat(g.batch_size)
        g = hgfp.mm.geometry_in_heterograph.from_heterograph_with_xyz(g)
        g = hgfp.mm.energy_in_heterograph.u(g)
        # g = unnorm(g)