        e_feat = None
        if self.gnn_model == "gin":
            x, all_outputs = self.gnn(g, n_feat, e_feat)
        else:
            x, all_outputs = self.gnn(g, n_feat, e_feat), None
        x = self.set2set(g, x)
        x = self.lin_readout(x)
        if self.norm:
            x = F.normalize(x, p=2, dim=-1, eps=1e-5)
        if return_all_outputs:
            return x, all_outputs
        else:
            return x


if __name__ == "__main__":
    model = GraphEncoder(gnn_model="gin")
    print(model)
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 0, 1, 2], [1, 2, 2, 1])
    g.ndata["pos_directed"] = torch.rand(3, 16)
    g.ndata["pos_undirected"] = torch.rand(3, 16)
    g.ndata["seed"] = torch.zeros(3, dtype=torch.long)
    g.ndata["nfreq"] = torch.ones(3, dtype=torch.long)
    g.edata["efreq"] = torch.ones(4, dtype=torch.long)
    g = dgl.batch([g, g, g])
    y = model(g)
    print(y.shape)
    print(y)
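# The __main__ block above builds the toy graph through the mutable dgl.DGLGraph()
# API, which newer DGL releases deprecate. Below is a minimal sketch of the same
# smoke test using the dgl.graph constructor; it assumes GraphEncoder is importable
# from the surrounding module and keeps the original feature names and shapes.
import dgl
import torch

def smoke_test(model):
    # One small graph: 3 nodes, 4 directed edges, same feature fields as above.
    g = dgl.graph(([0, 0, 1, 2], [1, 2, 2, 1]), num_nodes=3)
    g.ndata["pos_directed"] = torch.rand(3, 16)
    g.ndata["pos_undirected"] = torch.rand(3, 16)
    g.ndata["seed"] = torch.zeros(3, dtype=torch.long)
    g.ndata["nfreq"] = torch.ones(3, dtype=torch.long)
    g.edata["efreq"] = torch.ones(4, dtype=torch.long)
    # dgl.batch concatenates the three copies into one 9-node graph.
    bg = dgl.batch([g, g, g])
    y = model(bg)
    print(y.shape)  # one embedding per graph in the batch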
def batcher_dev(batch):
    batch_trees = dgl.batch(batch).to(device)
    return GCNBatch(graph=batch_trees, labels=batch_trees.ndata[LABEL_NODE_NAME])
def batcher_dev(batch):
    graph_q, graph_k = zip(*batch)
    graph_q, graph_k = dgl.batch(graph_q), dgl.batch(graph_k)
    return graph_q, graph_k
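# A collate function like this is meant to be handed to a PyTorch DataLoader.
# A minimal usage sketch, assuming `dataset` yields (graph_q, graph_k) pairs of
# DGLGraphs:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    collate_fn=batcher_dev, num_workers=4)

for graph_q, graph_k in loader:
    # Each of graph_q / graph_k is one batched DGLGraph holding 32 samples.
    ...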
def test_moleculenet(): if torch.cuda.is_available(): device = torch.device('cuda:0') else: device = torch.device('cpu') for dataset in ['BACE', 'BBBP', 'ClinTox', 'FreeSolv', 'HIV', 'MUV', 'SIDER', 'ToxCast', 'PCBA', 'ESOL', 'Lipophilicity', 'Tox21']: for featurizer_type in ['canonical', 'attentivefp']: if featurizer_type == 'canonical': node_featurizer = CanonicalAtomFeaturizer(atom_data_field='hv') edge_featurizer = CanonicalBondFeaturizer(bond_data_field='he', self_loop=True) else: node_featurizer = AttentiveFPAtomFeaturizer(atom_data_field='hv') edge_featurizer = AttentiveFPBondFeaturizer(bond_data_field='he', self_loop=True) for model_type in ['GCN', 'GAT']: g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer) g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer) bg = dgl.batch([g1, g2]) model = load_pretrained('{}_{}_{}'.format( model_type, featurizer_type, dataset)).to(device) with torch.no_grad(): model(bg.to(device), bg.ndata.pop('hv').to(device)) model.eval() model(g1.to(device), g1.ndata.pop('hv').to(device)) remove_file('{}_{}_{}_pre_trained.pth'.format( model_type.lower(), featurizer_type, dataset)) for model_type in ['Weave', 'MPNN', 'AttentiveFP']: g1 = smiles_to_bigraph('CO', add_self_loop=True, node_featurizer=node_featurizer, edge_featurizer=edge_featurizer) g2 = smiles_to_bigraph('CCO', add_self_loop=True, node_featurizer=node_featurizer, edge_featurizer=edge_featurizer) bg = dgl.batch([g1, g2]) model = load_pretrained('{}_{}_{}'.format( model_type, featurizer_type, dataset)).to(device) with torch.no_grad(): model(bg.to(device), bg.ndata.pop('hv').to(device), bg.edata.pop('he').to(device)) model.eval() model(g1.to(device), g1.ndata.pop('hv').to(device), g1.edata.pop('he').to(device)) remove_file('{}_{}_{}_pre_trained.pth'.format( model_type.lower(), featurizer_type, dataset)) if dataset == 'ClinTox': continue node_featurizer = PretrainAtomFeaturizer() edge_featurizer = PretrainBondFeaturizer() for model_type in ['gin_supervised_contextpred', 'gin_supervised_infomax', 'gin_supervised_edgepred', 'gin_supervised_masking']: g1 = smiles_to_bigraph('CO', add_self_loop=True, node_featurizer=node_featurizer, edge_featurizer=edge_featurizer) g2 = smiles_to_bigraph('CCO', add_self_loop=True, node_featurizer=node_featurizer, edge_featurizer=edge_featurizer) bg = dgl.batch([g1, g2]) model = load_pretrained('{}_{}'.format(model_type, dataset)).to(device) with torch.no_grad(): node_feats = [ bg.ndata.pop('atomic_number').to(device), bg.ndata.pop('chirality_type').to(device) ] edge_feats = [ bg.edata.pop('bond_type').to(device), bg.edata.pop('bond_direction_type').to(device) ] model(bg.to(device), node_feats, edge_feats) model.eval() node_feats = [ g1.ndata.pop('atomic_number').to(device), g1.ndata.pop('chirality_type').to(device) ] edge_feats = [ g1.edata.pop('bond_type').to(device), g1.edata.pop('bond_direction_type').to(device) ] model(g1.to(device), node_feats, edge_feats) remove_file('{}_{}_pre_trained.pth'.format(model_type.lower(), dataset))
    def forward(self, node_num=None, feat=None, spatial_feat=None, word2vec=None,
                roi_label=None, validation=False, choose_nodes=None, remove_nodes=None):
        # set up graph
        batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, \
            batch_o_o_e_list, batch_h_o_e_list, batch_readout_edge_list, \
            batch_readout_h_h_e_list, batch_readout_h_o_e_list = [], [], [], [], [], [], [], [], []
        node_num_cum = np.cumsum(node_num)  # !IMPORTANT
        for i in range(len(node_num)):
            # set node space
            node_space = 0
            if i != 0:
                node_space = node_num_cum[i - 1]
            graph, h_node_list, obj_node_list, h_h_e_list, o_o_e_list, h_o_e_list, \
                readout_edge_list, readout_h_h_e_list, readout_h_o_e_list = self._build_graph(
                    node_num[i], roi_label[i], node_space, diff_edge=self.diff_edge)
            # update batch graph
            batch_graph.append(graph)
            batch_h_node_list += h_node_list
            batch_obj_node_list += obj_node_list
            batch_h_h_e_list += h_h_e_list
            batch_o_o_e_list += o_o_e_list
            batch_h_o_e_list += h_o_e_list
            batch_readout_edge_list += readout_edge_list
            batch_readout_h_h_e_list += readout_h_h_e_list
            batch_readout_h_o_e_list += readout_h_o_e_list
        batch_graph = dgl.batch(batch_graph)

        # ipdb.set_trace()
        if not self.CONFIG1.feat_type == 'fc7':
            feat = self.graph_head(feat)

        # pass through gnn/gcn
        if self.layer == 1:
            self.grnn1(batch_graph, batch_h_node_list, batch_obj_node_list,
                       batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list,
                       feat, spatial_feat, word2vec, validation, initial_feat=True)
            # batch_graph.apply_edges(self.edge_readout,
            #     tuple(zip(*(batch_readout_h_o_e_list + batch_readout_h_h_e_list))))
            batch_graph.apply_edges(self.edge_readout,
                                    tuple(zip(*batch_readout_edge_list)))
        elif self.layer == 2:
            feat, feat_lang = self.grnn1(batch_graph, batch_h_node_list, batch_obj_node_list,
                                         batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list,
                                         feat, spatial_feat, word2vec, validation,
                                         pop_feat=True, initial_feat=True)
            self.grnn2(batch_graph, batch_h_node_list, batch_obj_node_list,
                       batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list,
                       feat, spatial_feat, feat_lang, validation)
            if self.diff_edge:
                # update node feature at the last layer
                if not len(batch_h_node_list) == 0:
                    batch_graph.apply_nodes(self.h_node_update, batch_h_node_list)
                if not len(batch_obj_node_list) == 0:
                    batch_graph.apply_nodes(self.o_node_update, batch_obj_node_list)
            else:
                batch_graph.apply_nodes(self.h_node_update,
                                        batch_h_node_list + batch_obj_node_list)
            batch_graph.apply_edges(
                self.edge_readout,
                tuple(zip(*(batch_readout_h_o_e_list + batch_readout_h_h_e_list))))
        else:
            feat, feat_lang = self.grnn1(batch_graph, batch_h_node_list, batch_obj_node_list,
                                         batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list,
                                         feat, spatial_feat, word2vec, validation,
                                         pop_feat=True, initial_feat=True)
            feat, feat_lang = self.grnn2(batch_graph, batch_h_node_list, batch_obj_node_list,
                                         batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list,
                                         feat, spatial_feat, feat_lang, validation,
                                         pop_feat=True)
            self.grnn3(batch_graph, batch_h_node_list, batch_obj_node_list,
                       batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list,
                       feat, spatial_feat, feat_lang, validation)
            if self.diff_edge:
                # update node feature at the last layer
                if not len(batch_h_node_list) == 0:
                    batch_graph.apply_nodes(self.h_node_update, batch_h_node_list)
                if not len(batch_obj_node_list) == 0:
                    batch_graph.apply_nodes(self.o_node_update, batch_obj_node_list)
            else:
                batch_graph.apply_nodes(self.h_node_update,
                                        batch_h_node_list + batch_obj_node_list)
            batch_graph.apply_edges(
                self.edge_readout,
                tuple(zip(*(batch_readout_h_o_e_list + batch_readout_h_h_e_list))))

        # import ipdb; ipdb.set_trace()
        if self.training or validation:
            # return batch_graph.edges[tuple(zip(*(batch_readout_h_o_e_list + batch_readout_h_h_e_list)))].data['pred']
            # !NOTE: cannot use "batch_readout_h_o_e_list + batch_readout_h_h_e_list" because of the wrong order
            return batch_graph.edges[tuple(zip(*batch_readout_edge_list))].data['pred']
        else:
            return batch_graph.edges[tuple(zip(*batch_readout_edge_list))].data['pred'], \
                batch_graph.nodes[batch_h_node_list].data['alpha'], \
                batch_graph.nodes[batch_h_node_list].data['alpha_lang']
###############################################################################
# The learning curve of a run is presented below.

plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()

###############################################################################
# The trained model is evaluated on the test set created. The tutorial restricts
# its running time, so training for longer is likely to give a higher accuracy
# (:math:`80` % ~ :math:`90` %) than the ones printed below.

model.eval()
# Convert a list of tuples to two lists
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
probs_Y = torch.softmax(model(test_bg), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:.4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))

###############################################################################
# The animation here plots the probability that a trained model predicts the correct graph type.
#
# .. image:: https://data.dgl.ai/tutorial/batch/test_eval4.gif
#
# To understand the node and graph representations that a trained model learned,
def test_graph6():
    """Batched graph with node types and edge distances."""
    g1 = dgl.graph(([0, 0, 1], [1, 2, 2]), idtype=torch.int32)
    g2 = dgl.graph(([0, 1, 1, 1], [1, 2, 3, 4]), idtype=torch.int32)
    bg = dgl.batch([g1, g2])
    return bg, torch.LongTensor([0, 1, 0, 2, 0, 3, 4, 4]), torch.randn(7, 1)
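# The returned label tensor has 3 + 5 = 8 entries and the edge-distance tensor
# 3 + 4 = 7 rows because dgl.batch concatenates the two graphs and offsets node
# IDs per graph. A short sketch for inspecting those offsets:
bg, node_types, edge_dists = test_graph6()
print(bg.batch_num_nodes())  # tensor([3, 5]): per-graph node counts
print(bg.batch_num_edges())  # tensor([3, 4]): per-graph edge counts
src, dst = bg.edges()
print(src, dst)              # node IDs of the second graph are shifted by 3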
def test_empty_relation(idtype): """Test the features of batched DGLHeteroGraphs""" g1 = dgl.heterograph( { ('user', 'follows', 'user'): ([0, 1], [1, 2]), ('user', 'plays', 'game'): ([], []) }, idtype=idtype, device=F.ctx()) g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]]) g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]]) g1.edges['follows'].data['h1'] = F.tensor([[0.], [1.]]) g1.edges['follows'].data['h2'] = F.tensor([[2.], [3.]]) g2 = dgl.heterograph( { ('user', 'follows', 'user'): ([0, 1], [1, 2]), ('user', 'plays', 'game'): ([0, 1], [0, 0]) }, idtype=idtype, device=F.ctx()) g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]]) g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]]) g2.nodes['game'].data['h1'] = F.tensor([[0.]]) g2.nodes['game'].data['h2'] = F.tensor([[1.]]) g2.edges['follows'].data['h1'] = F.tensor([[0.], [1.]]) g2.edges['follows'].data['h2'] = F.tensor([[2.], [3.]]) g2.edges['plays'].data['h1'] = F.tensor([[0.], [1.]]) bg = dgl.batch([g1, g2]) # Test number of nodes for ntype in bg.ntypes: assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [ g1.number_of_nodes(ntype), g2.number_of_nodes(ntype) ] # Test number of edges for etype in bg.canonical_etypes: assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == [ g1.number_of_edges(etype), g2.number_of_edges(etype) ] # Test features assert F.allclose( bg.nodes['user'].data['h1'], F.cat([g1.nodes['user'].data['h1'], g2.nodes['user'].data['h1']], dim=0)) assert F.allclose( bg.nodes['user'].data['h2'], F.cat([g1.nodes['user'].data['h2'], g2.nodes['user'].data['h2']], dim=0)) assert F.allclose(bg.nodes['game'].data['h1'], g2.nodes['game'].data['h1']) assert F.allclose(bg.nodes['game'].data['h2'], g2.nodes['game'].data['h2']) assert F.allclose( bg.edges['follows'].data['h1'], F.cat([g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']], dim=0)) assert F.allclose(bg.edges['plays'].data['h1'], g2.edges['plays'].data['h1']) # Test unbatching graphs g3, g4 = dgl.unbatch(bg) check_equivalence_between_heterographs(g1, g3, node_attrs={ 'user': ['h1', 'h2'], 'game': ['h1', 'h2'] }, edge_attrs={ ('user', 'follows', 'user'): ['h1'] }) check_equivalence_between_heterographs(g2, g4, node_attrs={ 'user': ['h1', 'h2'], 'game': ['h1', 'h2'] }, edge_attrs={ ('user', 'follows', 'user'): ['h1'] }) # Test graphs without edges g1 = dgl.heterograph({('u', 'r', 'v'): ([], [])}, {'u': 0, 'v': 4}) g2 = dgl.heterograph({('u', 'r', 'v'): ([], [])}, {'u': 1, 'v': 5}) dgl.batch([g1, g2])
def test_topology(gs, idtype): """Test batching two DGLHeteroGraphs where some nodes are isolated in some relations""" g1, g2 = gs g1 = g1.astype(idtype).to(F.ctx()) g2 = g2.astype(idtype).to(F.ctx()) bg = dgl.batch([g1, g2]) assert bg.idtype == idtype assert bg.device == F.ctx() assert bg.ntypes == g2.ntypes assert bg.etypes == g2.etypes assert bg.canonical_etypes == g2.canonical_etypes assert bg.batch_size == 2 # Test number of nodes for ntype in bg.ntypes: print(ntype) assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [ g1.number_of_nodes(ntype), g2.number_of_nodes(ntype) ] assert bg.number_of_nodes(ntype) == (g1.number_of_nodes(ntype) + g2.number_of_nodes(ntype)) # Test number of edges for etype in bg.canonical_etypes: assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == [ g1.number_of_edges(etype), g2.number_of_edges(etype) ] assert bg.number_of_edges(etype) == (g1.number_of_edges(etype) + g2.number_of_edges(etype)) # Test relabeled nodes for ntype in bg.ntypes: assert list(F.asnumpy(bg.nodes(ntype))) == list( range(bg.number_of_nodes(ntype))) # Test relabeled edges src, dst = bg.edges(etype=('user', 'follows', 'user')) assert list(F.asnumpy(src)) == [0, 1, 4, 5] assert list(F.asnumpy(dst)) == [1, 2, 5, 6] src, dst = bg.edges(etype=('user', 'follows', 'developer')) assert list(F.asnumpy(src)) == [0, 1, 4, 5] assert list(F.asnumpy(dst)) == [1, 2, 4, 5] src, dst, eid = bg.edges(etype='plays', form='all') assert list(F.asnumpy(src)) == [0, 1, 2, 3, 4, 5, 6] assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2, 2, 3] assert list(F.asnumpy(eid)) == [0, 1, 2, 3, 4, 5, 6] # Test unbatching graphs g3, g4 = dgl.unbatch(bg) check_equivalence_between_heterographs(g1, g3) check_equivalence_between_heterographs(g2, g4) # Test dtype cast if idtype == "int32": bg_cast = bg.long() else: bg_cast = bg.int() assert bg.batch_size == bg_cast.batch_size # Test local var bg_local = bg.local_var() assert bg.batch_size == bg_local.batch_size
def test_batching_batched(idtype): """Test batching a DGLHeteroGraph and a BatchedDGLHeteroGraph.""" g1 = dgl.heterograph( { ('user', 'follows', 'user'): ([0, 1], [1, 2]), ('user', 'plays', 'game'): ([0, 1], [0, 0]) }, idtype=idtype, device=F.ctx()) g2 = dgl.heterograph( { ('user', 'follows', 'user'): ([0, 1], [1, 2]), ('user', 'plays', 'game'): ([0, 1], [0, 0]) }, idtype=idtype, device=F.ctx()) bg1 = dgl.batch([g1, g2]) g3 = dgl.heterograph( { ('user', 'follows', 'user'): ([0], [1]), ('user', 'plays', 'game'): ([1], [0]) }, idtype=idtype, device=F.ctx()) bg2 = dgl.batch([bg1, g3]) assert bg2.idtype == idtype assert bg2.device == F.ctx() assert bg2.ntypes == g3.ntypes assert bg2.etypes == g3.etypes assert bg2.canonical_etypes == g3.canonical_etypes assert bg2.batch_size == 3 # Test number of nodes for ntype in bg2.ntypes: assert F.asnumpy(bg2.batch_num_nodes(ntype)).tolist() == [ g1.number_of_nodes(ntype), g2.number_of_nodes(ntype), g3.number_of_nodes(ntype) ] assert bg2.number_of_nodes(ntype) == (g1.number_of_nodes(ntype) + g2.number_of_nodes(ntype) + g3.number_of_nodes(ntype)) # Test number of edges for etype in bg2.canonical_etypes: assert F.asnumpy(bg2.batch_num_edges(etype)).tolist() == [ g1.number_of_edges(etype), g2.number_of_edges(etype), g3.number_of_edges(etype) ] assert bg2.number_of_edges(etype) == (g1.number_of_edges(etype) + g2.number_of_edges(etype) + g3.number_of_edges(etype)) # Test relabeled nodes for ntype in bg2.ntypes: assert list(F.asnumpy(bg2.nodes(ntype))) == list( range(bg2.number_of_nodes(ntype))) # Test relabeled edges src, dst = bg2.edges(etype='follows') assert list(F.asnumpy(src)) == [0, 1, 3, 4, 6] assert list(F.asnumpy(dst)) == [1, 2, 4, 5, 7] src, dst = bg2.edges(etype='plays') assert list(F.asnumpy(src)) == [0, 1, 3, 4, 7] assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2] # Test unbatching graphs g4, g5, g6 = dgl.unbatch(bg2) check_equivalence_between_heterographs(g1, g4) check_equivalence_between_heterographs(g2, g5) check_equivalence_between_heterographs(g3, g6)
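# The test above depends on dgl.batch flattening an already-batched graph rather
# than nesting it, which is why bg2.batch_size is 3 instead of 2. A minimal
# homogeneous-graph sketch of the same behaviour:
import dgl

g1 = dgl.graph(([0], [1]))
g2 = dgl.graph(([0, 1], [1, 2]))
bg1 = dgl.batch([g1, g2])                      # batch_size == 2
bg2 = dgl.batch([bg1, dgl.graph(([0], [1]))])  # nested batch is flattened
print(bg2.batch_size)                          # 3
print(len(dgl.unbatch(bg2)))                   # the 3 original graphs come back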
def test_features(idtype): """Test the features of batched DGLHeteroGraphs""" g1 = dgl.heterograph( { ('user', 'follows', 'user'): ([0, 1], [1, 2]), ('user', 'plays', 'game'): ([0, 1], [0, 0]) }, idtype=idtype, device=F.ctx()) g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]]) g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]]) g1.nodes['game'].data['h1'] = F.tensor([[0.]]) g1.nodes['game'].data['h2'] = F.tensor([[1.]]) g1.edges['follows'].data['h1'] = F.tensor([[0.], [1.]]) g1.edges['follows'].data['h2'] = F.tensor([[2.], [3.]]) g1.edges['plays'].data['h1'] = F.tensor([[0.], [1.]]) g2 = dgl.heterograph( { ('user', 'follows', 'user'): ([0, 1], [1, 2]), ('user', 'plays', 'game'): ([0, 1], [0, 0]) }, idtype=idtype, device=F.ctx()) g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]]) g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]]) g2.nodes['game'].data['h1'] = F.tensor([[0.]]) g2.nodes['game'].data['h2'] = F.tensor([[1.]]) g2.edges['follows'].data['h1'] = F.tensor([[0.], [1.]]) g2.edges['follows'].data['h2'] = F.tensor([[2.], [3.]]) g2.edges['plays'].data['h1'] = F.tensor([[0.], [1.]]) # test default setting bg = dgl.batch([g1, g2]) assert F.allclose( bg.nodes['user'].data['h1'], F.cat([g1.nodes['user'].data['h1'], g2.nodes['user'].data['h1']], dim=0)) assert F.allclose( bg.nodes['user'].data['h2'], F.cat([g1.nodes['user'].data['h2'], g2.nodes['user'].data['h2']], dim=0)) assert F.allclose( bg.nodes['game'].data['h1'], F.cat([g1.nodes['game'].data['h1'], g2.nodes['game'].data['h1']], dim=0)) assert F.allclose( bg.nodes['game'].data['h2'], F.cat([g1.nodes['game'].data['h2'], g2.nodes['game'].data['h2']], dim=0)) assert F.allclose( bg.edges['follows'].data['h1'], F.cat([g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']], dim=0)) assert F.allclose( bg.edges['follows'].data['h2'], F.cat([g1.edges['follows'].data['h2'], g2.edges['follows'].data['h2']], dim=0)) assert F.allclose( bg.edges['plays'].data['h1'], F.cat([g1.edges['plays'].data['h1'], g2.edges['plays'].data['h1']], dim=0)) # test specifying ndata/edata bg = dgl.batch([g1, g2], ndata=['h2'], edata=['h1']) assert F.allclose( bg.nodes['user'].data['h2'], F.cat([g1.nodes['user'].data['h2'], g2.nodes['user'].data['h2']], dim=0)) assert F.allclose( bg.nodes['game'].data['h2'], F.cat([g1.nodes['game'].data['h2'], g2.nodes['game'].data['h2']], dim=0)) assert F.allclose( bg.edges['follows'].data['h1'], F.cat([g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']], dim=0)) assert F.allclose( bg.edges['plays'].data['h1'], F.cat([g1.edges['plays'].data['h1'], g2.edges['plays'].data['h1']], dim=0)) assert 'h1' not in bg.nodes['user'].data assert 'h1' not in bg.nodes['game'].data assert 'h2' not in bg.edges['follows'].data # Test unbatching graphs g3, g4 = dgl.unbatch(bg) check_equivalence_between_heterographs(g1, g3, node_attrs={ 'user': ['h2'], 'game': ['h2'] }, edge_attrs={ ('user', 'follows', 'user'): ['h1'] }) check_equivalence_between_heterographs(g2, g4, node_attrs={ 'user': ['h2'], 'game': ['h2'] }, edge_attrs={ ('user', 'follows', 'user'): ['h1'] }) # test legacy bg = dgl.batch([g1, g2], edge_attrs=['h1']) assert 'h2' not in bg.edges['follows'].data.keys()
    def forward(self, batch):
        # ======================================================================================
        # Data preparation
        # ======================================================================================
        batch_size = len(batch['facts_num_nodes_list'])

        images = batch['features_list']  # [(36, 2048)]
        images = torch.stack(images).to(self.device)  # [batch, 36, 2048]
        img_relations = batch['img_relations_list']
        img_relations = torch.stack(img_relations).to(self.device)  # (batch, 36, 36, 7), 7 dims for now
        questions = batch['question_list']  # list((max_length,))
        questions = torch.stack(questions).long().to(self.device)  # [batch, max_length]
        questions_len_list = batch['question_length_list']
        questions_len_list = torch.tensor(batch['question_length_list']).long().to(self.device)
        fact_num_nodes_list = torch.Tensor(batch['facts_num_nodes_list']).long().to(self.device)
        facts_features_list = batch['facts_features_list']
        facts_features_list = [torch.Tensor(features).to(self.device)
                               for features in facts_features_list]
        facts_e1ids_list = batch['facts_e1ids_list']
        facts_e1ids_list = [torch.Tensor(e1ids).long().to(self.device)
                            for e1ids in facts_e1ids_list]
        facts_e2ids_list = batch['facts_e2ids_list']
        facts_e2ids_list = [torch.tensor(e2ids).long().to(self.device)
                            for e2ids in facts_e2ids_list]
        facts_answer_list = batch['facts_answer_list']
        facts_answer_list = [torch.tensor(answer).long().to(self.device)
                             for answer in facts_answer_list]

        ques_embed = self.que_glove_embed(questions).float()
        _, (ques_embed, _) = self.ques_rnn(ques_embed, questions_len_list)

        node_att_proj_ques_embed = self.node_att_proj_ques(ques_embed)
        node_att_proj_img_embed = self.node_att_proj_img(images)
        node_att_proj_ques_embed = node_att_proj_ques_embed.unsqueeze(1).repeat(1, images.shape[1], 1)
        node_att_proj_img_sum_ques = torch.tanh(node_att_proj_ques_embed + node_att_proj_img_embed)
        node_att_values = self.node_att_value(node_att_proj_img_sum_ques).squeeze()
        node_att_values = F.softmax(node_att_values, dim=-1)
        node_att_values = node_att_values.unsqueeze(-1).repeat(
            1, 1, self.config['model']['img_feature_size'])
        images = node_att_values * images  # (b, 36) * (b, 36, 2048)

        rel_att_proj_ques_embed = self.rel_att_proj_ques(ques_embed)  # shape (batch, 128)
        rel_att_proj_rel_embed = self.rel_att_proj_rel(img_relations)  # shape (batch, 36, 36, 128)
        rel_att_proj_ques_embed = rel_att_proj_ques_embed.repeat(1, 36 * 36).view(
            batch_size, 36, 36,
            self.config['model']['rel_att_ques_rel_proj_dims'])  # shape (batch, 36, 36, 128)
        rel_att_proj_rel_sum_ques = torch.tanh(rel_att_proj_ques_embed + rel_att_proj_rel_embed)
        rel_att_values = self.rel_att_value(rel_att_proj_rel_sum_ques).squeeze()  # shape (batch, 36, 36)
        rel_att_values_2 = rel_att_values.unsqueeze(-1).repeat(
            1, 1, 1, self.config['model']['relation_dims'])
        img_relations = rel_att_values_2 * img_relations  # (batch, 36, 36, 7)

        img_graphs = []
        for i in range(batch_size):
            g = dgl.DGLGraph()
            # add nodes
            g.add_nodes(36)
            # add node features
            g.ndata['h'] = images[i]
            g.ndata['batch'] = torch.full([36, 1], i)
            # add edges
            for s in range(36):
                for d in range(36):
                    g.add_edge(s, d)
            # add edge features
            g.edata['rel'] = img_relations[i].view(
                36 * 36, self.config['model']['relation_dims'])  # shape (36*36, 7)
            g.edata['att'] = rel_att_values[i].view(36 * 36, 1)  # shape (36*36, 1)
            img_graphs.append(g)
        image_batch_graph = dgl.batch(img_graphs)

        fact_graphs = []
        for i in range(batch_size):
            graph = dgl.DGLGraph()
            graph.add_nodes(fact_num_nodes_list[i])
            graph.add_edges(facts_e1ids_list[i], facts_e2ids_list[i])
            graph.ndata['h'] = facts_features_list[i]
            graph.ndata['batch'] = torch.full([fact_num_nodes_list[i], 1], i)
            graph.ndata['answer'] = facts_answer_list[i]
            fact_graphs.append(graph)
        fact_batch_graph = dgl.batch(fact_graphs)

        image_batch_graph = self.img_gcn1(image_batch_graph)
        fact_batch_graph = self.new_fact_gcn1(fact_batch_graph, image_batch_graph,
                                              ques_embed=ques_embed)
        fact_batch_graph.ndata['h'] = F.relu(fact_batch_graph.ndata['h'])
        image_batch_graph = self.img_gcn2(image_batch_graph)
        fact_batch_graph = self.new_fact_gcn2(fact_batch_graph, image_batch_graph,
                                              ques_embed=ques_embed)

        # one feature per node
        fact_batch_graph.ndata['h'] = torch.sigmoid(fact_batch_graph.ndata['h'])
        fact_batch_graph.ndata['h'] = torch.softmax(fact_batch_graph.ndata['h'], dim=0)
        return fact_batch_graph
def create_batch(batch_all):
    # print('Batch size : ', len(batch_all))
    # X, Y = batch_all[0]
    batch = [ProcessImage(batch_itr) for batch_itr in batch_all]

    lengths = [sample['seq_length'] for sample in batch]
    Input = [torch.FloatTensor(sample['Input']) for sample in batch]
    target = [torch.FloatTensor(sample['target']) for sample in batch]
    # position = [torch.FloatTensor(sample['position']) for sample in batch]
    graph = [sample['gr'] for sample in batch]
    energy = [torch.FloatTensor(sample['energy']) for sample in batch]
    position = [torch.Tensor(sample['point_xyz']) for sample in batch]
    position_idx = [torch.LongTensor(sample['point_idx_zxy']) for sample in batch]

    max_length = np.max(lengths)
    n_sequences = len(target)

    Input_tensor = torch.ones((n_sequences, 7, 64, 64)).float() * -5
    targets_tensor = torch.ones((n_sequences, 6, 64, 64)).float() * -5
    position_idx_tensor = torch.ones((n_sequences, max_length, 3)).int() * -1
    position_idx_tensor = position_idx_tensor.long()
    position_tensor = torch.ones((n_sequences, max_length, 3)).int() * -999.
    energy_tensor = torch.zeros((n_sequences, max_length)).float()

    for i in range(n_sequences):
        seq_len = lengths[i]
        Input_tensor[i] = Input[i]
        targets_tensor[i] = target[i]
        position_tensor[i, :seq_len] = position[i]
        position_idx_tensor[i, :seq_len] = position_idx[i]
        energy_tensor[i, :seq_len] = energy[i]

    sequence_lengths = torch.LongTensor(lengths)

    # sequence_lengths, idx = sequence_lengths.sort(dim=0, descending=True)
    # targets_tensor = targets_tensor[idx]
    # energy_tensor = energy_tensor[idx]
    # position_tensor = position_tensor[idx]
    # position_idx_tensor = position_idx_tensor[idx]
    # Input_tensor = Input_tensor[idx]
    # graph = sorted(graph, key=lambda g: g.number_of_nodes(), reverse=True)

    # pos = torch.where(sequence_lengths > graph_size)[0]
    # sequence_lengths = sequence_lengths[pos]
    # targets_tensor = targets_tensor[pos]
    # energy_tensor = energy_tensor[pos]
    # position_tensor = position_tensor[pos]
    # position_idx_tensor = position_idx_tensor[pos]
    # Input_tensor = Input_tensor[pos]
    # graph = graph[0: len(pos)]

    return dgl.batch(graph), targets_tensor, sequence_lengths, energy_tensor, \
        position_tensor, position_idx_tensor, Input_tensor
def collate_fn(batch):
    graphs, pmpds, labels = zip(*batch)
    batched_graphs = dgl.batch(graphs)
    batched_pmpds = sp.block_diag(pmpds)
    batched_labels = np.concatenate(labels, axis=0)
    return batched_graphs, batched_pmpds, batched_labels
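# sp.block_diag stacks the per-sample matrices into one scipy sparse matrix in
# COO format. If the downstream model expects a torch sparse tensor instead, a
# hedged conversion helper (the name coo_to_torch_sparse is ours) could look like:
import numpy as np
import torch

def coo_to_torch_sparse(mat):
    """Convert a scipy sparse matrix to a torch sparse COO tensor."""
    coo = mat.tocoo()
    indices = torch.from_numpy(np.vstack((coo.row, coo.col))).long()
    values = torch.from_numpy(coo.data)
    return torch.sparse_coo_tensor(indices, values, coo.shape)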
def test_pickling_graph(): # graph structures and frames are pickled g = dgl.DGLGraph() g.add_nodes(3) src = F.tensor([0, 0]) dst = F.tensor([1, 2]) g.add_edges(src, dst) x = F.randn((3, 7)) y = F.randn((3, 5)) a = F.randn((2, 6)) b = F.randn((2, 4)) g.ndata['x'] = x g.ndata['y'] = y g.edata['a'] = a g.edata['b'] = b # registered functions are pickled g.register_message_func(_global_message_func) reduce_func = fn.sum('x', 'x') g.register_reduce_func(reduce_func) # custom attributes should be pickled g.foo = 2 new_g = _reconstruct_pickle(g) _assert_is_identical(g, new_g) assert new_g.foo == 2 assert new_g._message_func == _global_message_func assert isinstance(new_g._reduce_func, type(reduce_func)) assert new_g._reduce_func._name == 'sum' assert new_g._reduce_func.msg_field == 'x' assert new_g._reduce_func.out_field == 'x' # test batched graph with partial set case g2 = dgl.DGLGraph() g2.add_nodes(4) src2 = F.tensor([0, 1]) dst2 = F.tensor([2, 3]) g2.add_edges(src2, dst2) x2 = F.randn((4, 7)) y2 = F.randn((3, 5)) a2 = F.randn((2, 6)) b2 = F.randn((2, 4)) g2.ndata['x'] = x2 g2.nodes[[0, 1, 3]].data['y'] = y2 g2.edata['a'] = a2 g2.edata['b'] = b2 bg = dgl.batch([g, g2]) bg2 = _reconstruct_pickle(bg) _assert_is_identical(bg, bg2) new_g, new_g2 = dgl.unbatch(bg2) _assert_is_identical(g, new_g) _assert_is_identical(g2, new_g2) # readonly graph g = dgl.DGLGraph([(0, 1), (1, 2)], readonly=True) new_g = _reconstruct_pickle(g) _assert_is_identical(g, new_g) # multigraph g = dgl.DGLGraph([(0, 1), (0, 1), (1, 2)]) new_g = _reconstruct_pickle(g) _assert_is_identical(g, new_g) # readonly multigraph g = dgl.DGLGraph([(0, 1), (0, 1), (1, 2)], readonly=True) new_g = _reconstruct_pickle(g) _assert_is_identical(g, new_g)
def main(training_file, dev_file, test_file, task, previous_model=None): global stop_training graph_type = task['graph_type'] net = task['gnn_network'] epochs = task['epochs'] patience = task['patience'] grid_width = task['grid_width'] image_width = task['image_width'] batch_size = task['batch_size'] num_hidden = task['num_gnn_units'] heads = task['num_gnn_heads'] residual = False lr = task['lr'] weight_decay = task['weight_decay'] gnn_layers = task['num_gnn_layers'] cnn_layers = task['num_cnn_layers'] in_drop = task['in_drop'] alpha = task['alpha'] attn_drop = task['attn_drop'] num_rels = task['num_rels'] fw = task['fw'] identifier = task['identifier'] nonlinearity = task['non-linearity'] min_train_loss = float("inf") min_dev_loss = float("inf") if nonlinearity == 'relu': nonlinearity = F.relu elif nonlinearity == 'elu': nonlinearity = F.elu output_list_records_train_loss = [] output_list_records_dev_loss = [] loss_fcn = torch.nn.MSELoss() #(reduction='sum') print('=========================') print('HEADS', heads) print('GNN LAYERS', gnn_layers) print('CNN LAYERS', cnn_layers) print('HIDDEN', num_hidden) print('RESIDUAL', residual) print('inDROP', in_drop) print('atDROP', attn_drop) print('LR', lr) print('DECAY', weight_decay) print('ALPHA', alpha) print('BATCH', batch_size) print('GRAPH_ALT', graph_type) print('ARCHITECTURE', net) print('=========================') # create the dataset print('Loading training set...') train_dataset = socnavImg.SocNavDataset(training_file, mode='train') print('Loading dev set...') valid_dataset = socnavImg.SocNavDataset(dev_file, mode='valid') print('Loading test set...') test_dataset = socnavImg.SocNavDataset(test_file, mode='test') print('Done loading files') train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate) valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate) test_dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate) if num_rels < 0: num_rels = len(socnavImg.get_relations()) cur_step = 0 best_loss = -1 num_feats = train_dataset.graphs[0].ndata['h'].shape[1] print('Number of features: {}'.format(num_feats)) g = dgl.batch(train_dataset.graphs) # define the model if fw == 'dgl': if net in ['gat']: model = GAT( g, # graph gnn_layers, # gnn_layers num_feats, # in_dimension num_hidden, # num_hidden 1, grid_width, # grid_width heads, # head nonlinearity, # activation in_drop, # feat_drop attn_drop, # attn_drop alpha, # negative_slope residual, # residual cnn_layers # cnn_layers ) elif net in ['gatmc']: model = GATMC( g, # graph gnn_layers, # gnn_layers num_feats, # in_dimension num_hidden, # num_hidden grid_width, # grid_width image_width, # image_width heads, # head nonlinearity, # activation in_drop, # feat_drop attn_drop, # attn_drop alpha, # negative_slope residual, # residual cnn_layers # cnn_layers ) elif net in ['rgcn']: print( f'CREATING RGCN(GRAPH, gnn_layers:{gnn_layers}, cnn_layers:{cnn_layers}, num_feats:{num_feats}, num_hidden:{num_hidden}, grid_with:{grid_width}, image_width:{image_width}, num_rels:{num_rels}, non-linearity:{nonlinearity}, drop:{in_drop}, num_bases:{num_rels})' ) model = RGCN(g, gnn_layers, cnn_layers, num_feats, num_hidden, grid_width, image_width, num_rels, nonlinearity, in_drop, num_bases=num_rels) else: print('No valid GNN model specified') sys.exit(0) optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) if previous_model is not None: model.load_state_dict(torch.load(previous_model, 
map_location=device)) model = model.to(device) for epoch in range(epochs): if stop_training: print("Stopping training. Please wait.") break model.train() loss_list = [] for batch, data in enumerate(train_dataloader): subgraph, labels = data subgraph.set_n_initializer(dgl.init.zero_initializer) subgraph.set_e_initializer(dgl.init.zero_initializer) feats = subgraph.ndata['h'].to(device) labels = labels.to(device) if fw == 'dgl': model.g = subgraph for layer in model.layers: layer.g = subgraph if net in ['rgcn']: logits = model( feats.float(), subgraph.edata['rel_type'].squeeze().to(device)) else: logits = model(feats.float()) else: print('Only DGL is supported at the moment here.') sys.exit(1) if net in ['pgat', 'pgcn']: data = Data(x=feats.float(), edge_index=torch.stack( subgraph.edges()).to(device)) else: data = Data( x=feats.float(), edge_index=torch.stack(subgraph.edges()).to(device), edge_type=subgraph.edata['rel_type'].squeeze().to( device)) logits = model(data, subgraph) a = logits.flatten() b = labels.float().flatten() ad = a.to(device) bd = b.to(device) loss = loss_fcn(ad, bd) optimizer.zero_grad() a = list(model.parameters())[0].clone() loss.backward() optimizer.step() b = list(model.parameters())[0].clone() not_learning = torch.equal(a.data, b.data) if not_learning: print('Not learning') sys.exit(1) else: pass loss_list.append(loss.item()) loss_data = np.array(loss_list).mean() if loss_data < min_train_loss: min_train_loss = loss_data print('Loss: {}'.format(loss_data)) output_list_records_train_loss.append(float(loss_data)) if epoch % 5 == 0: print("Epoch {:05d} | Loss: {:.6f} | Patience: {} | ".format( epoch, loss_data, cur_step), end='') score_list = [] val_loss_list = [] for batch, valid_data in enumerate(valid_dataloader): subgraph, labels = valid_data subgraph.set_n_initializer(dgl.init.zero_initializer) subgraph.set_e_initializer(dgl.init.zero_initializer) feats = subgraph.ndata['h'].to(device) labels = labels.to(device) score, val_loss = evaluate(feats.float(), model, subgraph, labels.float(), loss_fcn, fw, net) score_list.append(score) val_loss_list.append(val_loss) mean_score = np.array(score_list).mean() mean_val_loss = np.array(val_loss_list).mean() print("Score: {:.6f} MEAN: {:.6f} BEST: {:.6f}".format( mean_score, mean_val_loss, best_loss)) output_list_records_dev_loss.append(mean_val_loss) # early stop if best_loss > mean_val_loss or best_loss < 0: print('Saving...') directory = str(identifier).zfill(5) try: os.mkdir(directory) except: print('Exception creating directory', directory) best_loss = mean_val_loss if best_loss < min_dev_loss: min_dev_loss = best_loss # Save the model model.eval() torch.save(model.state_dict(), directory + '/SNGNN2D.tch') params = { 'train_loss': min_train_loss, 'dev_loss': min_dev_loss, 'net': net, 'fw': fw, 'gnn_layers': gnn_layers, 'cnn_layers': cnn_layers, 'num_feats': num_feats, 'num_hidden': num_hidden, 'graph_type': graph_type, 'heads': heads, 'grid_width': grid_width, 'image_width': image_width, 'F': nonlinearity, 'in_drop': in_drop, 'attn_drop': attn_drop, 'alpha': alpha, 'residual': residual, 'num_rels': num_rels, 'train_scores': output_list_records_train_loss, 'dev_scores': output_list_records_dev_loss } pickle.dump(params, open(directory + '/SNGNN2D.prms', 'wb')) cur_step = 0 else: cur_step += 1 if cur_step >= patience: break time_a = time.time() test_score_list = [] for batch, test_data in enumerate(test_dataloader): subgraph, labels = test_data subgraph.set_n_initializer(dgl.init.zero_initializer) 
subgraph.set_e_initializer(dgl.init.zero_initializer) feats = subgraph.ndata['h'].to(device) labels = labels.to(device) test_score_list.append( evaluate(feats, model, subgraph, labels.float(), loss_fcn, fw, net)[1]) time_b = time.time() time_delta = float(time_b - time_a) test_loss = np.array(test_score_list).mean() print("MSE for the test set {}".format(test_loss)) return min_train_loss, min_dev_loss, test_loss, time_delta, num_of_params( model ), epoch, output_list_records_train_loss, output_list_records_dev_loss
def collate(sample):
    graphs, feats, labels = map(list, zip(*sample))
    graph = dgl.batch(graphs)
    feats = torch.from_numpy(np.concatenate(feats))
    labels = torch.from_numpy(np.concatenate(labels))
    return graph, feats, labels
def batcher_dev(batch):
    graphs, labels = zip(*batch)
    batch_graphs = dgl.batch(graphs)
    labels = torch.stack(labels, 0)
    return AlchemyBatcher(graph=batch_graphs, label=labels)
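# AlchemyBatcher is not defined in this snippet; it is presumably a small named
# container for the batched graph and its labels. One plausible definition
# (an assumption, not necessarily the project's actual code):
from collections import namedtuple

# Assumed container; the real project may define it differently.
AlchemyBatcher = namedtuple('AlchemyBatcher', ['graph', 'label'])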
def collate(samples):
    # The input `samples` is a list of pairs (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)
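# Inside a model, the batched graph produced by this collate function is usually
# reduced back to one vector per input graph before classification; a minimal
# hedged readout sketch:
import dgl

def graph_readout(batched_graph, node_feats):
    """Mean-pool node features into one vector per graph in the batch."""
    with batched_graph.local_scope():
        batched_graph.ndata['h'] = node_feats
        return dgl.mean_nodes(batched_graph, 'h')  # (batch_size, feat_dim)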
def __init__(self, tree_list):
    graph_list = []
    for tree in tree_list:
        graph_list.append(tree.dgl_graph)
    self.batch_dgl_graph = dgl.batch(graph_list)
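# Since only the batched graph is stored, the per-tree graphs can be recovered
# later with dgl.unbatch; a short sketch assuming tree_list as in the constructor:
import dgl

batched = dgl.batch([tree.dgl_graph for tree in tree_list])
trees_back = dgl.unbatch(batched)
assert len(trees_back) == len(tree_list)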
def collate(samples):
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)
def dfs_assemble(self, mol_tree_msg, mol_vec, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id): nodes_dict = mol_tree_msg.nodes_dict fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None cur_node = nodes_dict[cur_node_id] fa_nid = fa_node['nid'] if fa_node is not None else -1 prev_nodes = [fa_node] if fa_node is not None else [] children_node_id = [ v for v in mol_tree_msg.successors(cur_node_id).tolist() if nodes_dict[v]['nid'] != fa_nid ] children = [nodes_dict[v] for v in children_node_id] neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1] neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True) singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1] neighbors = singletons + neighbors cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node['nid']] cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap) if len(cands) == 0: return None cand_smiles, cand_mols, cand_amap = list(zip(*cands)) cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols] cand_graphs, atom_x, bond_x, tree_mess_src_edges, \ tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec( cands) cand_graphs = batch(cand_graphs) atom_x = cuda(atom_x) bond_x = cuda(bond_x) cand_graphs.ndata['x'] = atom_x cand_graphs.edata['x'] = bond_x cand_graphs.edata['src_x'] = atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_() cand_vecs = self.jtmpn( (cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes), mol_tree_msg, ) cand_vecs = self.G_mean(cand_vecs) mol_vec = mol_vec.squeeze() scores = cand_vecs @ mol_vec _, cand_idx = torch.sort(scores, descending=True) backup_mol = Chem.RWMol(cur_mol) for i in range(len(cand_idx)): cur_mol = Chem.RWMol(backup_mol) pred_amap = cand_amap[cand_idx[i].item()] new_global_amap = copy.deepcopy(global_amap) for nei_id, ctr_atom, nei_atom in pred_amap: if nei_id == fa_nid: continue new_global_amap[nei_id][nei_atom] = new_global_amap[ cur_node['nid']][ctr_atom] cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap) new_mol = cur_mol.GetMol() new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol)) if new_mol is None: continue result = True for nei_node_id, nei_node in zip(children_node_id, children): if nei_node['is_leaf']: continue cur_mol = self.dfs_assemble(mol_tree_msg, mol_vec, cur_mol, new_global_amap, pred_amap, nei_node_id, cur_node_id) if cur_mol is None: result = False break if result: return cur_mol return None
def pack_graph(graphs: Union[Tuple, List]) -> Tuple:
    # merge many dgl graphs into a huge one
    root_indices, node_nums = get_root_node_info(graphs)
    packed_graph = dgl.batch(graphs)
    return packed_graph, root_indices, node_nums
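# get_root_node_info is not shown here. If the convention is that node 0 of every
# graph is its root, the root indices inside the packed graph are simply cumulative
# node-count offsets; a hedged sketch of such a helper:
from typing import Sequence, Tuple
import torch

def get_root_node_info(graphs: Sequence) -> Tuple[torch.Tensor, torch.Tensor]:
    """Assumes node 0 of each graph is its root."""
    node_nums = torch.tensor([g.number_of_nodes() for g in graphs])
    # Offset of each graph's node 0 after dgl.batch concatenates the graphs.
    root_indices = torch.cat([torch.zeros(1, dtype=node_nums.dtype),
                              node_nums.cumsum(0)[:-1]])
    return root_indices, node_nums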
def __call__(self, src_buf, tgt_buf, device='cpu', src_deps=None): ''' Return a batched graph for the training phase of Transformer. args: src_buf: a set of input sequence arrays. tgt_buf: a set of output sequence arrays. device: 'cpu' or 'cuda:*' src_deps: list, optional Dependency parses of the source in the form of src_node_id -> dst_node_id. where src is the child and dst is the parent. i.e a child node attends on its syntactic parent in a dependency parse ''' if src_deps is None: src_deps = list() g_list = [] src_lens = [len(_) for _ in src_buf] tgt_lens = [len(_) - 1 for _ in tgt_buf] num_edges = {'ee': [], 'ed': [], 'dd': []} # We are running over source and target pairs here for src_len, tgt_len in zip(src_lens, tgt_lens): i, j = src_len - 1, tgt_len - 1 g_list.append(self.g_pool[i][j]) for key in ['ee', 'ed', 'dd']: num_edges[key].append(int(self.num_edges[key][i][j])) g = dgl.batch(g_list) src, tgt, tgt_y = [], [], [] src_pos, tgt_pos = [], [] enc_ids, dec_ids = [], [] e2e_eids, d2d_eids, e2d_eids = [], [], [] layer_eids = {'dep': [[], []]} n_nodes, n_edges, n_tokens = 0, 0, 0 for src_sample, tgt_sample, src_dep, n, m, n_ee, n_ed, n_dd in zip( src_buf, tgt_buf, src_deps, src_lens, tgt_lens, num_edges['ee'], num_edges['ed'], num_edges['dd']): src.append(th.tensor(src_sample, dtype=th.long, device=device)) tgt.append(th.tensor(tgt_sample[:-1], dtype=th.long, device=device)) tgt_y.append( th.tensor(tgt_sample[1:], dtype=th.long, device=device)) src_pos.append(th.arange(n, dtype=th.long, device=device)) tgt_pos.append(th.arange(m, dtype=th.long, device=device)) enc_ids.append( th.arange(n_nodes, n_nodes + n, dtype=th.long, device=device)) n_nodes += n dec_ids.append( th.arange(n_nodes, n_nodes + m, dtype=th.long, device=device)) n_nodes += m e2e_eids.append( th.arange(n_edges, n_edges + n_ee, dtype=th.long, device=device)) # Copy the ids of edges that correspond to a given node and its previous N nodes # We are using arange here. This will not work. Instead we need to select edges that # correspond to previous positions. This information is present in graph pool # For each edge, we need to figure out source_node_id and target_node_id. if src_dep: for i in range(0, 2): for src_node_id, dst_node_id in dedupe_tuples( get_src_dst_deps(src_dep, i + 1)): layer_eids['dep'][i].append(n_edges + src_node_id * n + dst_node_id) layer_eids['dep'][i].append(n_edges + dst_node_id * n + src_node_id) n_edges += n_ee e2d_eids.append( th.arange(n_edges, n_edges + n_ed, dtype=th.long, device=device)) n_edges += n_ed d2d_eids.append( th.arange(n_edges, n_edges + n_dd, dtype=th.long, device=device)) n_edges += n_dd n_tokens += m g.set_n_initializer(dgl.init.zero_initializer) g.set_e_initializer(dgl.init.zero_initializer) return Graph(g=g, src=(th.cat(src), th.cat(src_pos)), tgt=(th.cat(tgt), th.cat(tgt_pos)), tgt_y=th.cat(tgt_y), nids={ 'enc': th.cat(enc_ids), 'dec': th.cat(dec_ids) }, eids={ 'ee': th.cat(e2e_eids), 'ed': th.cat(e2d_eids), 'dd': th.cat(d2d_eids) }, nid_arr={ 'enc': enc_ids, 'dec': dec_ids }, n_nodes=n_nodes, layer_eids={ 'dep': [ th.tensor(layer_eids['dep'][i]) for i in range(0, len(layer_eids['dep'])) ] }, n_edges=n_edges, n_tokens=n_tokens)
axs[0, 0].set_yticks([])
for row, mode in enumerate(['spmm', 'sddmm']):
    filename = 'reddit_{}.csv'.format(mode)
    Xs, coo_gflops, csr_gflops = get_data(filename)
    axs[row + 1, 0].plot(Xs, coo_gflops, linewidth=1, label='coo')
    axs[row + 1, 0].plot(Xs, csr_gflops, linewidth=1, label='csr')
    axs[row + 1, 0].set_yticks([0, 200, 400])
    axs[row + 1, 0].set_xscale('log', basex=2)
    axs[row + 1, 0].set_xticks(Xs)
axs[2, 0].set_xlabel('reddit')

# mesh
for i, k in enumerate([32]):  # enumerate([8, 16, 32, 64]):
    f = open('modelnet40_{}.g'.format(k), 'rb')
    gs = pickle.load(f)
    g = dgl.batch(gs)
    axs[0, i + 1].hist(g.in_degrees(), range=(0, 100))
    axs[0, i + 1].set_yticks([])
    for row, mode in enumerate(['spmm', 'sddmm']):
        filename = 'mesh_{}_{}.csv'.format(k, mode)
        Xs, coo_gflops, csr_gflops = get_data(filename)
        axs[row + 1, i + 1].plot(Xs, coo_gflops, linewidth=1, label='coo')
        axs[row + 1, i + 1].plot(Xs, csr_gflops, linewidth=1, label='csr')
        axs[row + 1, i + 1].set_yticks([0, 200, 400])
        axs[row + 1, i + 1].set_xscale('log', basex=2)
        axs[row + 1, i + 1].set_xticks(Xs)
    axs[2, i + 1].set_xlabel('mesh {}'.format(k))

"""
# ppi
ppi = PPIDataset('train')
def beam(self, src_buf, start_sym, max_len, k, device='cpu', src_deps=None): ''' Return a batched graph for beam search during inference of Transformer. args: src_buf: a list of input sequence start_sym: the index of start-of-sequence symbol max_len: maximum length for decoding k: beam size device: 'cpu' or 'cuda:*' ''' if src_deps is None: src_deps = list() g_list = [] src_lens = [len(_) for _ in src_buf] tgt_lens = [max_len] * len(src_buf) num_edges = {'ee': [], 'ed': [], 'dd': []} for src_len, tgt_len in zip(src_lens, tgt_lens): i, j = src_len - 1, tgt_len - 1 for _ in range(k): g_list.append(self.g_pool[i][j]) for key in ['ee', 'ed', 'dd']: num_edges[key].append(int(self.num_edges[key][i][j])) g = dgl.batch(g_list) src, tgt = [], [] src_pos, tgt_pos = [], [] enc_ids, dec_ids = [], [] layer_eids = {'dep': [[], []]} e2e_eids, e2d_eids, d2d_eids = [], [], [] n_nodes, n_edges, n_tokens = 0, 0, 0 for src_sample, src_dep, n, n_ee, n_ed, n_dd in zip( src_buf, src_deps, src_lens, num_edges['ee'], num_edges['ed'], num_edges['dd']): for _ in range(k): src.append(th.tensor(src_sample, dtype=th.long, device=device)) src_pos.append(th.arange(n, dtype=th.long, device=device)) enc_ids.append( th.arange(n_nodes, n_nodes + n, dtype=th.long, device=device)) n_nodes += n e2e_eids.append( th.arange(n_edges, n_edges + n_ee, dtype=th.long, device=device)) # Copy the ids of edges that correspond to a given node and its previous N nodes # We are using arange here. This will not work. Instead we need to select edges that # correspond to previous positions. This information is present in graph pool # For each edge, we need to figure out source_node_id and target_node_id. if src_dep: for i in range(0, 2): for src_node_id, dst_node_id in dedupe_tuples( get_src_dst_deps(src_dep, i + 1)): layer_eids['dep'][i].append(n_edges + src_node_id * n + dst_node_id) layer_eids['dep'][i].append(n_edges + dst_node_id * n + src_node_id) n_edges += n_ee tgt_seq = th.zeros(max_len, dtype=th.long, device=device) tgt_seq[0] = start_sym tgt.append(tgt_seq) tgt_pos.append(th.arange(max_len, dtype=th.long, device=device)) dec_ids.append( th.arange(n_nodes, n_nodes + max_len, dtype=th.long, device=device)) n_nodes += max_len e2d_eids.append( th.arange(n_edges, n_edges + n_ed, dtype=th.long, device=device)) n_edges += n_ed d2d_eids.append( th.arange(n_edges, n_edges + n_dd, dtype=th.long, device=device)) n_edges += n_dd g.set_n_initializer(dgl.init.zero_initializer) g.set_e_initializer(dgl.init.zero_initializer) return Graph(g=g, src=(th.cat(src), th.cat(src_pos)), tgt=(th.cat(tgt), th.cat(tgt_pos)), tgt_y=None, nids={ 'enc': th.cat(enc_ids), 'dec': th.cat(dec_ids) }, eids={ 'ee': th.cat(e2e_eids), 'ed': th.cat(e2d_eids), 'dd': th.cat(d2d_eids) }, nid_arr={ 'enc': enc_ids, 'dec': dec_ids }, n_nodes=n_nodes, n_edges=n_edges, layer_eids={ 'dep': [ th.tensor(layer_eids['dep'][i]) for i in range(0, len(layer_eids['dep'])) ] }, n_tokens=n_tokens)
split = int(n * .8)
index = np.arange(n)
np.random.seed(32)
np.random.shuffle(index)
train_index, test_index = index[:split], index[split:]

# prep labels
train_labels, test_labels = Variable(
    torch.LongTensor((df['label'].values + 1)[train_index])), Variable(
        torch.LongTensor((df['label'].values + 1)[test_index]))

# prep temporal graph data
k = args.period
trainGs, testGs = [dgls[i] for i in train_index], [dgls[i] for i in test_index]
trainGs, testGs = [dgl.batch([u[i] for u in trainGs]) for i in range(k)], \
                  [dgl.batch([u[i] for u in testGs]) for i in range(k)]
train_inputs, test_inputs = [inputs[i] for i in train_index], [inputs[i] for i in test_index]
train_inputs, test_inputs = [torch.FloatTensor(np.concatenate([inp[i] for inp in train_inputs])) for i in range(k)], \
                            [torch.FloatTensor(np.concatenate([inp[i] for inp in test_inputs])) for i in range(k)]
train_xav, test_xav = [xav[i] for i in train_index], [xav[i] for i in test_index]
train_xav, test_xav = [torch.FloatTensor(np.concatenate([inp[i] for inp in train_xav])) for i in range(k)], \
                      [torch.FloatTensor(np.concatenate([inp[i] for inp in test_xav])) for i in range(k)]
def test_simple_pool(): ctx = F.ctx() g = dgl.DGLGraph(nx.path_graph(15)) sum_pool = nn.SumPooling() avg_pool = nn.AvgPooling() max_pool = nn.MaxPooling() sort_pool = nn.SortPooling(10) # k = 10 print(sum_pool, avg_pool, max_pool, sort_pool) # test#1: basic h0 = F.randn((g.number_of_nodes(), 5)) sum_pool = sum_pool.to(ctx) avg_pool = avg_pool.to(ctx) max_pool = max_pool.to(ctx) sort_pool = sort_pool.to(ctx) h1 = sum_pool(g, h0) assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0)) h1 = avg_pool(g, h0) assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0)) h1 = max_pool(g, h0) assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0)) h1 = sort_pool(g, h0) assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.dim() == 2 # test#2: batched graph g_ = dgl.DGLGraph(nx.path_graph(5)) bg = dgl.batch([g, g_, g, g_, g]) h0 = F.randn((bg.number_of_nodes(), 5)) h1 = sum_pool(bg, h0) truth = th.stack([ F.sum(h0[:15], 0), F.sum(h0[15:20], 0), F.sum(h0[20:35], 0), F.sum(h0[35:40], 0), F.sum(h0[40:55], 0) ], 0) assert F.allclose(h1, truth) h1 = avg_pool(bg, h0) truth = th.stack([ F.mean(h0[:15], 0), F.mean(h0[15:20], 0), F.mean(h0[20:35], 0), F.mean(h0[35:40], 0), F.mean(h0[40:55], 0) ], 0) assert F.allclose(h1, truth) h1 = max_pool(bg, h0) truth = th.stack([ F.max(h0[:15], 0), F.max(h0[15:20], 0), F.max(h0[20:35], 0), F.max(h0[35:40], 0), F.max(h0[40:55], 0) ], 0) assert F.allclose(h1, truth) h1 = sort_pool(bg, h0) assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.dim() == 2
def batcher_dev(batch):
    graph_q, label = zip(*batch)
    graph_q = dgl.batch(graph_q)
    return graph_q, torch.LongTensor(label)
def main(): logger.info('Reading data and extra features...') fp_files = [] if opt.fp is None else opt.fp.split(',') fp_extra, y_array, name_array = dataloader.load(opt.input, opt.target, fp_files) smiles_list = [name.split()[0] for name in name_array] logger.info('Generating molecular graphs with %s...' % opt.graph) if opt.graph == 'rdk': graph_list, feats_list = smi2dgl(smiles_list) elif opt.graph == 'msd': msd_list = ['%s.msd' % base64.b64encode(smiles.encode()).decode() for smiles in smiles_list] graph_list, feats_list = msd2dgl(msd_list, '../data/msdfiles.zip') else: raise logger.info('Node feature example: (size=%d) %s' % (len(feats_list[0][0]), ','.join(map(str, feats_list[0][0])))) logger.info('Extra graph feature example: (size=%d) %s' % (len(fp_extra[0]), ','.join(map(str, fp_extra[0])))) logger.info('Output example: (size=%d) %s' % (len(y_array[0]), ','.join(map(str, y_array[0])))) if fp_extra.shape[-1] > 0: logger.info('Normalizing extra graph features...') scaler = preprocessing.Scaler() scaler.fit(fp_extra) scaler.save(opt.output + '/scale.txt') fp_extra = scaler.transform(fp_extra) logger.info('Selecting data...') selector = preprocessing.Selector(smiles_list) if opt.part is not None: logger.info('Loading partition file %s' % opt.part) selector.load(opt.part) else: logger.warning('Partition file not provided. Using auto-partition instead') selector.partition(0.8, 0.2) device = torch.device('cuda:0') # batched data for training set data_list = [[data[i] for i in np.where(selector.train_index)[0]] for data in (graph_list, y_array, feats_list, fp_extra, name_array, smiles_list)] n_batch, (graphs_batch, y_batch, feats_node_batch, feats_extra_batch, names_batch) = \ preprocessing.separate_batches(data_list[:-1], opt.batch, data_list[-1]) bg_batch_train = [dgl.batch(graphs).to(device) for graphs in graphs_batch] y_batch_train = [torch.tensor(y, dtype=torch.float32, device=device) for y in y_batch] feats_node_batch_train = [torch.tensor(np.concatenate(feats_node), dtype=torch.float32, device=device) for feats_node in feats_node_batch] feats_extra_batch_train = [torch.tensor(feats_extra, dtype=torch.float32, device=device) for feats_extra in feats_extra_batch] # for plot y_train_array = np.concatenate(y_batch) names_train = np.concatenate(names_batch) # data for validation set graphs, y, feats_node, feats_extra, names_valid = \ [[data[i] for i in np.where(selector.valid_index)[0]] for data in (graph_list, y_array, feats_list, fp_extra, name_array)] bg_valid, y_valid, feats_node_valid, feats_extra_valid = ( dgl.batch(graphs).to(device), torch.tensor(y, dtype=torch.float32, device=device), torch.tensor(np.concatenate(feats_node), dtype=torch.float32, device=device), torch.tensor(feats_extra, dtype=torch.float32, device=device), ) # for plot y_valid_array = y_array[selector.valid_index] logger.info('Training size = %d, Validation size = %d' % (len(y_train_array), len(y_valid_array))) logger.info('Batches = %d, Batch size ~= %d' % (n_batch, opt.batch)) in_feats_node = feats_list[0].shape[-1] in_feats_extra = fp_extra[0].shape[-1] n_heads = list(map(int, opt.head.split(','))) logger.info('Building network...') logger.info('Conv layers = %s' % n_heads) logger.info('Learning rate = %s' % opt.lr) logger.info('L2 penalty = %f' % opt.l2) model = GATModel(in_feats_node, opt.embed, n_head_list=n_heads, extra_feats=in_feats_extra) model.cuda() print(model) for name, param in model.named_parameters(): print(name, param.data.shape) header = 'Step MaxRE(t) Loss MeaSquE MeaSigE MeaUnsE MaxRelE Acc2% 
Acc5% Acc10%'.split() logger.info('%-8s %8s %8s %8s %8s %8s %8s %8s %8s %8s' % tuple(header)) optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.l2) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt.lrsteps, gamma=opt.lrgamma) for epoch in range(opt.epoch): model.train() if (epoch + 1) % opt.check == 0: pred_train = [None] * n_batch for ib in np.random.permutation(n_batch): optimizer.zero_grad() pred = model(bg_batch_train[ib], feats_node_batch_train[ib], feats_extra_batch_train[ib]) loss = F.mse_loss(pred, y_batch_train[ib]) loss.backward() optimizer.step() if (epoch + 1) % opt.check == 0: pred_train[ib] = pred.detach().cpu().numpy() scheduler.step() if (epoch + 1) % opt.check == 0: model.eval() pred_train = np.concatenate(pred_train) pred_valid = model(bg_valid, feats_node_valid, feats_extra_valid).detach().cpu().numpy() err_line = '%-8i %8.1f %8.2e %8.2e %8.1f %8.1f %8.1f %8.1f %8.1f %8.1f' % ( epoch + 1, metrics.max_relative_error(y_train_array, pred_train) * 100, metrics.mean_squared_error(y_train_array, pred_train), metrics.mean_squared_error(y_valid_array, pred_valid), metrics.mean_signed_error(y_valid_array, pred_valid) * 100, metrics.mean_unsigned_error(y_valid_array, pred_valid) * 100, metrics.max_relative_error(y_valid_array, pred_valid) * 100, metrics.accuracy(y_valid_array, pred_valid, 0.02) * 100, metrics.accuracy(y_valid_array, pred_valid, 0.05) * 100, metrics.accuracy(y_valid_array, pred_valid, 0.10) * 100) logger.info(err_line) torch.save(model, opt.output + '/model.pt') visualizer = visualize.LinearVisualizer(y_train_array.reshape(-1), pred_train.reshape(-1), names_train, 'train') visualizer.append(y_valid_array.reshape(-1), pred_valid.reshape(-1), names_valid, 'valid') visualizer.dump(opt.output + '/fit.txt') visualizer.dump_bad_molecules(opt.output + '/error-0.10.txt', 'valid', threshold=0.1) visualizer.dump_bad_molecules(opt.output + '/error-0.20.txt', 'valid', threshold=0.2) visualizer.scatter_yy(savefig=opt.output + '/error-train.png', annotate_threshold=0.1, marker='x', lw=0.2, s=5) visualizer.hist_error(savefig=opt.output + '/error-hist.png', label='valid', histtype='step', bins=50) plt.show()