def mergeGraph(graph_a, graph_b, train_data):
    """Merge two graphs into one bidirectional DGLGraph linked by training pairs.

    Nodes of ``graph_a`` keep their indices; nodes of ``graph_b`` are shifted
    by ``len(graph_a.id2idx)``. Every edge of both inputs is added in both
    directions. Each row ``i`` of ``train_data`` additionally links node
    ``train_data[i, 0]`` of ``graph_a`` with node ``train_data[i, 1]`` of
    ``graph_b`` (again in both directions).

    Parameters
    ----------
    graph_a, graph_b
        Objects exposing ``id2idx``, ``edge_src``, ``edge_dst`` and
        ``features``.
    train_data
        Two-dimensional indexable with a ``shape`` attribute; one row per
        cross-graph link.

    Returns
    -------
    DGLGraph
        The merged graph. ``ndata['features']`` holds the features of
        ``graph_a`` followed by those of ``graph_b``, moved to CUDA.
    """
    # Index shift that moves graph_b's nodes past graph_a's node range.
    offset = len(graph_a.id2idx)

    g = DGLGraph()
    g.add_nodes(offset + len(graph_b.id2idx))
    g.add_edges(graph_a.edge_src, graph_a.edge_dst)
    g.add_edges(graph_a.edge_dst, graph_a.edge_src)

    shifted_src = [x + offset for x in graph_b.edge_src]
    shifted_dst = [x + offset for x in graph_b.edge_dst]
    g.add_edges(shifted_src, shifted_dst)
    g.add_edges(shifted_dst, shifted_src)

    print(train_data.shape)
    print(g.number_of_edges())

    # Edges are added pair by pair so that edge ids stay interleaved
    # (forward, backward, forward, backward, ...).
    for i in range(train_data.shape[0]):
        node_a = train_data[i, 0]
        node_b = train_data[i, 1] + offset
        g.add_edge(node_a, node_b)
        g.add_edge(node_b, node_a)

    g.ndata['features'] = torch.cat([
        torch.FloatTensor(graph_a.features),
        torch.FloatTensor(graph_b.features)
    ], 0).cuda()
    return g
def mol2dgl(smiles_batch):
    """Convert a batch of SMILES strings into a list of DGLGraphs.

    For every molecule one node is added per atom and two directed edges
    (one per direction) per bond. Node features are stored under ``'x'``;
    when the molecule has bonds, edge features are stored under ``'x'`` and a
    zero tensor with ``ATOM_FDIM`` columns is stored under ``'src_x'``.

    Parameters
    ----------
    smiles_batch : iterable
        SMILES strings, each passed to ``get_mol``.

    Returns
    -------
    list
        One DGLGraph per input string, in input order.
    """
    graph_list = []
    for smiles in smiles_batch:
        atom_feature_list = []
        bond_feature_list = []
        graph = DGLGraph()
        mol = get_mol(smiles)

        for atom in mol.GetAtoms():
            graph.add_node(atom.GetIdx())
            atom_feature_list.append(atom_features(atom))

        for bond in mol.GetBonds():
            begin_idx = bond.GetBeginAtom().GetIdx()
            end_idx = bond.GetEndAtom().GetIdx()
            features = bond_features(bond)
            graph.add_edge(begin_idx, end_idx)
            bond_feature_list.append(features)
            # The reverse direction shares the same bond features.
            graph.add_edge(end_idx, begin_idx)
            bond_feature_list.append(features)

        atom_x = torch.stack(atom_feature_list)
        graph.set_n_repr({'x': atom_x})

        # A molecule without bonds has no edge features to stack.
        if bond_feature_list:
            bond_x = torch.stack(bond_feature_list)
            graph.set_e_repr({
                'x': bond_x,
                'src_x': atom_x.new(len(bond_feature_list), ATOM_FDIM).zero_()
            })
        graph_list.append(graph)
    return graph_list
def test_pull_0deg():
    """Exercise ``pull`` on a two-node graph where node 0 has no in-edge."""

    def _copy_src(edges):
        return {'m' : edges.src['h']}

    def _add_mailbox_sum(nodes):
        return {'h' : nodes.data['h'] + F.sum(nodes.mailbox['m'], 1)}

    def _double(nodes):
        return {'h' : nodes.data['h'] * 2}

    def _twos(shape, dtype, ctx, ids):
        return 2 + F.zeros(shape, dtype, ctx)

    graph = DGLGraph()
    graph.add_nodes(2)
    graph.add_edge(0, 1)
    graph.register_message_func(_copy_src)
    graph.register_reduce_func(_add_mailbox_sum)
    graph.register_apply_node_func(_double)
    graph.set_n_initializer(_twos, 'h')

    # Case 1: pull on both the zero-in-degree node and the regular node.
    before = F.randn((2, 5))
    graph.ndata['h'] = before
    graph.pull([0, 1])
    after = graph.ndata.pop('h')
    # Node 0 is expected to hold the initializer value (2) after apply (x2).
    expected_zero_deg = F.full_1d(5, 4, dtype=F.float32)
    assert F.allclose(after[0], expected_zero_deg)
    # Node 1 is expected to hold twice the sum over both original rows.
    expected_regular = F.sum(before, 0) * 2
    assert F.allclose(after[1], expected_regular)

    # Case 2: pull on the zero-in-degree node only.
    before = F.randn((2, 5))
    graph.ndata['h'] = before
    graph.pull(0)
    after = graph.ndata.pop('h')
    # Node 0 is expected to go through the apply function only.
    assert F.allclose(after[0], 2*before[0])
    # Node 1 is expected to be unchanged.
    assert F.allclose(after[1], before[1])
class DataEntry:
    """One sample: a DGLGraph with node features, typed edges and a target."""

    def __init__(self, datset, num_nodes, features, edges, target):
        """Build the graph for this entry.

        :param datset: owner object; must provide ``get_edge_type_number``
        :param num_nodes: number of nodes to create in the graph
        :param features: node features, converted with ``torch.FloatTensor``
        :param edges: iterable of ``(source, edge_type, target)`` triples
        :param target: value stored unchanged on ``self.target``
        """
        self.dataset = datset
        self.num_nodes = num_nodes
        self.target = target
        self.features = torch.FloatTensor(features)

        self.graph = DGLGraph()
        self.graph.add_nodes(self.num_nodes, data={'features': self.features})

        # Each edge carries its numeric type id under the 'etype' key.
        for src, edge_kind, dst in edges:
            type_id = self.dataset.get_edge_type_number(edge_kind)
            edge_data = {'etype': torch.LongTensor([type_id])}
            self.graph.add_edge(src, dst, data=edge_data)
def batch_transform_bert_fever(inst, bert_max_len, bert_tokenizer):
    """Turn one claim instance into a fully connected DGLGraph.

    Every ordered pair of distinct nodes is connected. For each node the
    encoded sequence returned by ``encode_sequence_fever`` is stored under
    ``'encoding'``, ``'encoding_mask'`` and ``'segment_id'``; a float
    ``'label'`` is stored when the node label is 0 or 1. As a side effect
    each ``node['name']`` is replaced by its ``str`` form.

    Parameters
    ----------
    inst : dict
        Must provide ``'node'``, ``'question'``, ``'label'`` and ``'qid'``.
    bert_max_len
        Forwarded to ``encode_sequence_fever``.
    bert_tokenizer
        Forwarded to ``encode_sequence_fever``.

    Returns
    -------
    tuple
        ``(graph, inst['qid'], label)`` where ``label`` is 0 for ``'false'``,
        1 for ``'half-true'`` and 2 for ``'true'``.

    Raises
    ------
    ValueError
        If ``inst['label']`` is none of the three known strings. Previously
        this case printed ``'Problem!'`` and then failed with an
        ``UnboundLocalError`` on the return statement.
    """
    nodes = inst['node']
    num_nodes = len(nodes)
    question = inst['question']

    g = DGLGraph()
    g.add_nodes(num_nodes)
    for i in range(num_nodes):
        for j in range(num_nodes):
            if i != j:
                g.add_edge(i, j)

    for i, node in enumerate(nodes):
        # Nodes whose label is neither 1 nor 0 get no explicit label.
        if node['label'] == 1:
            g.nodes[i].data['label'] = torch.tensor(1).unsqueeze(0).type(
                torch.FloatTensor)
        elif node['label'] == 0:
            g.nodes[i].data['label'] = torch.tensor(0).unsqueeze(0).type(
                torch.FloatTensor)

        node['name'] = str(node['name'])
        title = node['name'].replace('_', ' ')
        context = node['context']
        sent_num = node['sent_num']
        encoding_inputs, encoding_masks, encoding_ids = encode_sequence_fever(
            question, title, context, bert_max_len, bert_tokenizer, sent_num)
        g.nodes[i].data['encoding'] = encoding_inputs.unsqueeze(0)
        g.nodes[i].data['encoding_mask'] = encoding_masks.unsqueeze(0)
        g.nodes[i].data['segment_id'] = encoding_ids.unsqueeze(0)

    # An earlier mapping over 'REFUTES' / 'NOT ENOUGH INFO' / 'SUPPORTS' was
    # disabled in favour of the three labels below.
    if inst['label'] == 'false':
        label = 0
    elif inst['label'] == 'half-true':
        label = 1
    elif inst['label'] == 'true':
        label = 2
    else:
        raise ValueError(
            'Problem! Unexpected instance label: {!r}'.format(inst['label']))
    return g, inst['qid'], label
def process(mol: Mol, device: torch.device, **kwargs) -> GATData:
    """Build the graph and padded feature matrix of a molecule for GAT.

    The graph has one node per atom plus an extra node 0. Every node gets a
    self-loop, every atom node gets an edge into node 0, and each bond is
    added in both directions between the (1-shifted) atom nodes.

    :param mol: molecule providing ``GetNumAtoms`` and ``GetBonds``
    :param device: device the feature matrix is moved to
    :param kwargs: accepted but not used
    :return: ``GATData(node count, graph, features)``; row 0 of the features
        is all zeros and belongs to the extra node
    """
    node_count = mol.GetNumAtoms() + 1

    graph = DGLGraph()
    graph.add_nodes(node_count)
    graph.add_edges(graph.nodes(), graph.nodes())
    # Only the atom -> node 0 direction is added; node 0 -> atoms is not.
    graph.add_edges(range(1, node_count), 0)

    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx() + 1
        end = bond.GetEndAtomIdx() + 1
        graph.add_edge(begin, end)
        graph.add_edge(end, begin)

    atom_feats, feat_dim = feature.mol_feature(mol)
    padded = torch.cat([torch.zeros((1, feat_dim)), atom_feats])
    return GATData(node_count, graph, padded.to(device))
def process(mol: Mol, device: torch.device, **kwargs):
    """Build the dense adjacency and padded features of a molecule.

    A temporary graph is created with one node per atom plus an extra node 0.
    Every node gets a self-loop, every atom node gets an edge into node 0,
    and each bond is added in both directions between the (1-shifted) atom
    nodes. The graph is then converted to a dense adjacency matrix.

    :param mol: molecule providing ``GetNumAtoms`` and ``GetBonds``
    :param device: device the feature matrix is moved to
    :param kwargs: accepted but not used
    :return: ``ChebNetData(node count, adjacency, features)``; row 0 of the
        features is all zeros and belongs to the extra node
    """
    node_count = mol.GetNumAtoms() + 1

    graph = DGLGraph()
    graph.add_nodes(node_count)
    graph.add_edges(graph.nodes(), graph.nodes())
    # Only the atom -> node 0 direction is added; node 0 -> atoms is not.
    graph.add_edges(range(1, node_count), 0)

    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx() + 1
        end = bond.GetEndAtomIdx() + 1
        graph.add_edge(begin, end)
        graph.add_edge(end, begin)

    sparse_adj = graph.adjacency_matrix(transpose=False)
    adj = sparse_adj.to_dense()

    atom_feats, feat_dim = feature.mol_feature(mol)
    padded = torch.cat([torch.zeros((1, feat_dim)), atom_feats])
    return ChebNetData(node_count, adj, padded.to(device))
def test_recv_0deg_newfld():
    """Exercise ``recv`` with a zero-in-degree node when the reducer writes a new field."""

    def _copy_src(edges):
        return {'m': edges.src['h']}

    def _add_mailbox_sum(nodes):
        return {'h1': nodes.data['h'] + F.sum(nodes.mailbox['m'], 1)}

    def _double(nodes):
        return {'h1': nodes.data['h1'] * 2}

    def _twos(shape, dtype, ctx, ids):
        return 2 + F.zeros(shape, dtype=dtype, ctx=ctx)

    graph = DGLGraph()
    graph.add_nodes(2)
    graph.add_edge(0, 1)
    graph.register_message_func(_copy_src)
    graph.register_reduce_func(_add_mailbox_sum)
    graph.register_apply_node_func(_double)

    # Case 1: recv on both the zero-in-degree node and the regular node.
    before = F.randn((2, 5))
    graph.set_n_initializer(_twos, 'h1')
    graph.ndata['h'] = before
    graph.send((0, 1))
    graph.recv([0, 1])
    after = graph.ndata.pop('h1')
    # Node 0 is expected to hold the initializer value (2) after apply (x2).
    expected_zero_deg = F.full_1d(5, 4, dtype=F.float32)
    assert F.allclose(after[0], expected_zero_deg)
    # Node 1 is expected to hold twice the sum over both original rows.
    expected_regular = F.sum(before, 0) * 2
    assert F.allclose(after[1], expected_regular)

    # Case 2: recv on the zero-in-degree node only.
    before = F.randn((2, 5))
    graph.ndata['h'] = before
    # The 'h1' field has to exist beforehand for this case.
    graph.ndata['h1'] = F.full((2, 5), -1, F.int64)
    graph.send((0, 1))
    graph.recv(0)
    after = graph.ndata.pop('h1')
    # Node 0 is expected to go through the apply function only: -1 * 2.
    assert F.allclose(after[0], F.full_1d(5, -2, F.int64))
    # Node 1 is expected to keep its preset value.
    assert F.allclose(after[1], F.full_1d(5, -1, F.int64))
def create_g(file_path, use_cuda=False):
    """Load a graph stored in an ``.npz`` archive into a DGLGraph.

    The archive must contain the arrays ``'labels'``, ``'fts_node'``,
    ``'edge_type'``, ``'edge_norm'`` and ``'edges'``. One node is created per
    label and one edge per row of ``'edges'`` (``row[0] -> row[1]``).

    Parameters
    ----------
    file_path
        Path (or file object) accepted by ``numpy.load``.
    use_cuda : bool
        When true, labels, node ids and edge data are moved to CUDA.

    Returns
    -------
    list
        ``[graph, labels, fts_nodes]``. The graph stores ``fts_nodes`` (long)
        under ``ndata['id']``, the edge types (long) under
        ``edata['rel_type']`` and the edge norms (float, with an added
        trailing dimension) under ``edata['norm']``.
    """
    # NOTE(review): allow_pickle=True unpickles objects stored in the file,
    # which can execute arbitrary code; only load archives you trust.
    # The context manager closes the archive once the arrays are read.
    with np.load(file_path, allow_pickle=True) as npz:
        labels = npz['labels']
        fts_nodes = npz['fts_node']
        # Round-trip through lists and rebuild the arrays with numpy.
        edge_type = np.array(npz['edge_type'].tolist())
        edge_norm = np.array(npz['edge_norm'].tolist())
        edges = npz['edges']

    num_nodes = len(labels)
    g = DGLGraph()
    g.add_nodes(num_nodes)
    for edge in edges:
        g.add_edge(edge[0].item(), edge[1].item())

    edge_type = torch.from_numpy(edge_type).long()
    edge_norm = torch.from_numpy(edge_norm).unsqueeze(1).float()
    fts_nodes = torch.from_numpy(fts_nodes).long()
    labels = torch.from_numpy(labels)

    if use_cuda:
        labels = labels.cuda()
        edge_type = edge_type.cuda()
        edge_norm = edge_norm.cuda()
        fts_nodes = fts_nodes.cuda()

    g.edata.update({'rel_type': edge_type, 'norm': edge_norm})
    g.ndata['id'] = fts_nodes
    return [g, labels, fts_nodes]
def test_update_all_0deg():
    """Exercise ``update_all`` when some or all nodes have no incoming edge."""

    def _copy_src(edges):
        return {'m': edges.src['h']}

    def _add_mailbox_sum(nodes):
        return {'h': nodes.data['h'] + F.sum(nodes.mailbox['m'], 1)}

    def _double(nodes):
        return {'h': nodes.data['h'] * 2}

    def _twos(shape, dtype, ctx, ids):
        return 2 + F.zeros(shape, dtype, ctx)

    # Case 1: nodes 1..4 all point at node 0.
    star = DGLGraph()
    star.add_nodes(5)
    for src in range(1, 5):
        star.add_edge(src, 0)
    star.set_n_initializer(_twos, 'h')

    old_repr = F.randn((5, 5))
    star.ndata['h'] = old_repr
    star.update_all(_copy_src, _add_mailbox_sum, _double)
    result = star.ndata['h']
    # Rows 1..4 (no in-edges) are expected to equal the initializer value
    # passed through the apply function; row 0 is expected to be twice the
    # sum of all original rows.
    expected_zero_deg = 2 * (2 + F.zeros((4, 5)))
    assert F.allclose(result[1:], expected_zero_deg)
    assert F.allclose(result[0], 2 * F.sum(old_repr, 0))

    # Case 2: a graph without any edge.
    empty = DGLGraph()
    empty.add_nodes(5)
    empty.set_n_initializer(_twos, 'h')
    empty.ndata['h'] = old_repr
    empty.update_all(_copy_src, _add_mailbox_sum, _double)
    result = empty.ndata['h']
    # Only the apply function is expected to run.
    assert F.allclose(result, 2 * old_repr)
def generate_graph(grad=False):
    """Create a 10-node, 17-edge test graph with random node and edge data.

    Node 0 points at nodes 1..8, each of which points at node 9, and one
    extra edge goes from node 9 back to node 0. Random features with
    ``D`` columns (``D`` is read from module scope) are stored under
    ``ndata['h']`` and ``edata['w']``.

    :param grad: when true, both feature tensors go through
        ``F.attach_grad`` before being stored
    :return: the populated DGLGraph with zero initializers installed for
        node and edge data
    """
    graph = DGLGraph()
    graph.add_nodes(10)

    # 8 x (source -> middle, middle -> sink) = 16 edges ...
    for middle in range(1, 9):
        graph.add_edge(0, middle)
        graph.add_edge(middle, 9)
    # ... plus the back edge from the sink to the source = 17 edges.
    graph.add_edge(9, 0)

    node_feats = F.randn((10, D))
    edge_feats = F.randn((17, D))
    if grad:
        node_feats = F.attach_grad(node_feats)
        edge_feats = F.attach_grad(edge_feats)

    graph.ndata['h'] = node_feats
    graph.edata['w'] = edge_feats
    graph.set_n_initializer(dgl.init.zero_initializer)
    graph.set_e_initializer(dgl.init.zero_initializer)
    return graph
def test_dynamic_addition():
    """Check that feature columns stay aligned while nodes and edges are added."""
    base_nodes = 3
    feat_dim = 1

    graph = DGLGraph()
    graph = graph.to(F.ctx())

    def _edge_rows():
        # Row counts of the two edge feature columns.
        return graph.edata['h1'].shape[0], graph.edata['h2'].shape[0]

    # Nodes: set two columns, then grow the graph by three more nodes.
    graph.add_nodes(base_nodes)
    graph.ndata.update({'h1': F.randn((base_nodes, feat_dim)),
                        'h2': F.randn((base_nodes, feat_dim))})
    graph.add_nodes(3)
    assert graph.ndata['h1'].shape[0] == graph.ndata['h2'].shape[0] == base_nodes + 3

    # Edges: two edges with both columns set.
    graph.add_edge(0, 1)
    graph.add_edge(1, 0)
    graph.edata.update({'h1': F.randn((2, feat_dim)),
                        'h2': F.randn((2, feat_dim))})
    rows_h1, rows_h2 = _edge_rows()
    assert rows_h1 == rows_h2 == 2

    # Two more edges; only 'h1' is overwritten for all four.
    graph.add_edges([0, 2], [2, 0])
    graph.edata['h1'] = F.randn((4, feat_dim))
    rows_h1, rows_h2 = _edge_rows()
    assert rows_h1 == rows_h2 == 4

    # A fifth edge whose 'h1' row is written through the edge view.
    graph.add_edge(1, 2)
    graph.edges[4].data['h1'] = F.randn((1, feat_dim))
    rows_h1, rows_h2 = _edge_rows()
    assert rows_h1 == rows_h2 == 5

    # An edge added together with only one of the two feature columns.
    graph.add_edge(2, 1, {'h1': F.randn((1, feat_dim))})
    assert len(graph.edata['h1']) == len(graph.edata['h2'])
def reduced_graph(graph:DGLGraph,center_node:int,paths:dict,node_attr_name,edge_attr_name):
    """
    Collapse a graph into a star around one center node.

    For every ``(node, path)`` entry the weights of the path's edges are
    multiplied together, the edge at position ``k`` being scaled by
    ``exp(-k)``, and the product becomes the weight of the two edges
    ``center_node <-> node`` in the new graph. A self-loop of weight one is
    then added on every node, and finally all edge weights are passed through
    a softmax over dimension 0.

    :param graph: the graph to reduce
    :param center_node: the reference node every other node is attached to
    :param paths: mapping ``node -> path``, each path a sequence of edges
        indexable as ``edge[0]`` (source) and ``edge[1]`` (destination)
    :param node_attr_name: node data key copied over from ``graph``
    :param edge_attr_name: edge data key holding the weights
    :return: new_graph
    """
    star = DGLGraph()
    star.add_nodes(num=graph.number_of_nodes())
    star.ndata[node_attr_name] = graph.ndata[node_attr_name]

    for target, hops in paths.items():
        weight = torch.tensor([1.])
        for depth, hop in enumerate(hops):
            edge_index = graph.edge_id(hop[0], hop[1])
            # Later hops contribute less through the exp(-depth) factor.
            weight *= graph.edata[edge_attr_name][edge_index]*math.exp(-depth)
        star.add_edge(center_node, target, data={edge_attr_name: weight})
        star.add_edge(target, center_node, data={edge_attr_name: weight})

    self_loop_weights = torch.ones(star.number_of_nodes(), )
    star.add_edges(star.nodes(), star.nodes(),
                   data={edge_attr_name: self_loop_weights})

    normalised = star.edata[edge_attr_name].softmax(dim=0)
    star.edata[edge_attr_name] = normalised
    return star
def process(self, mol: Mol, atom_map):
    """Build the graph and one-hot atom features of a molecule for GCN.

    The graph has one node per atom plus an extra node 0. Every node gets a
    self-loop, every atom node gets an edge into node 0, and each bond is
    added in both directions between the (1-shifted) atom nodes.

    :param mol: molecule providing ``GetNumAtoms``, ``GetBonds`` and
        ``GetAtoms``
    :param atom_map: mapping from atomic number to one-hot class index
    :return: ``GCNData(node count, graph, features)``; row 0 of the features
        is all zeros and belongs to the extra node
    """
    node_count = mol.GetNumAtoms() + 1

    graph = DGLGraph()
    graph.add_nodes(node_count)
    graph.add_edges(graph.nodes(), graph.nodes())
    # Only the atom -> node 0 direction is added; node 0 -> atoms is not.
    graph.add_edges(range(1, node_count), 0)

    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx() + 1
        end = bond.GetEndAtomIdx() + 1
        graph.add_edge(begin, end)
        graph.add_edge(end, begin)

    extra_node_row = torch.zeros((1, self.feature_dim), device=self.device)
    class_ids = torch.tensor(
        [atom_map[atom.GetAtomicNum()] for atom in mol.GetAtoms()],
        device=self.device)
    atom_rows = torch.nn.functional.one_hot(
        class_ids, num_classes=self.feature_dim).to(torch.float)
    feature = torch.cat([extra_node_row, atom_rows])
    return GCNData(node_count, graph, feature)
def get_graph_from_smile(molecule_smile):
    """
    Method that constructs a molecular graph with nodes being the atoms
    and bonds being the edges.

    Node features are stored on the graph under ``ndata['x']`` and edge
    features under ``edata['w']``. For every ordered atom pair ``(i, j)``
    that shares a bond one directed edge ``i -> j`` is added, so each bond
    contributes one edge per direction.

    :param molecule_smile: SMILE sequence
    :return: DGL graph object (the features are attached to it, not returned
        separately)
    :raises ValueError: if ``Chem.MolFromSmiles`` returns ``None`` for the
        given string
    """
    molecule = Chem.MolFromSmiles(molecule_smile)
    if molecule is None:
        # Fail with a clear message instead of an opaque error further down.
        raise ValueError(
            'Could not parse SMILES string: {!r}'.format(molecule_smile))

    G = DGLGraph()
    num_atoms = molecule.GetNumAtoms()
    features = rdDesc.GetFeatureInvariants(molecule)

    # Chirality tag per atom index; 0 for atoms that are not reported by
    # FindMolChiralCenters.
    chiral_centers = [0] * num_atoms
    for atom_idx, tag in Chem.FindMolChiralCenters(molecule):
        chiral_centers[atom_idx] = tag

    G.add_nodes(num_atoms)
    node_features = []
    edge_features = []
    for i in range(num_atoms):
        atom_i = molecule.GetAtomWithIdx(i)
        node_features.append(
            get_atom_features(atom_i, chiral_centers[i], features[i]))
        # Scan in (i, j) order so edge ids match the edge feature rows.
        for j in range(num_atoms):
            bond_ij = molecule.GetBondBetweenAtoms(i, j)
            if bond_ij is not None:
                G.add_edge(i, j)
                edge_features.append(get_bond_features(bond_ij))

    G.ndata['x'] = np.array(node_features)
    G.edata['w'] = np.array(edge_features)
    return G
def _encode_with_bert(token_ids, attention_masks, chunk_size=50):
    """Encode rows with the module-level ``bert_model`` in fixed-size chunks.

    The first output of every model call is detached and moved to CPU; the
    per-chunk results are concatenated along dimension 0.
    """
    encoded = []
    for start in range(0, token_ids.size(0), chunk_size):
        stop = start + chunk_size
        encoding, _ = bert_model(token_ids[start:stop], None,
                                 attention_masks[start:stop])
        encoded.append(encoding.detach().cpu())
    return torch.cat(encoded, dim=0)


def _evidence_edge_tensors(evidences, vocab, num_edges, max_seq_length):
    """Tokenize at most ``num_edges`` evidence sentences into padded tensors.

    ``vocab`` is forwarded as the second argument of ``text_tokenize``.
    Returns ``(token_ids, token_masks, sent_mask)``: the first two are long
    tensors of shape ``(1, num_edges, max_seq_length)`` padded with zeros,
    the last is a byte tensor of shape ``(1, num_edges)`` holding 0 for used
    sentence slots and 1 for unused ones.
    """
    evidence_tokens = []
    evidence_masks = []
    for evi_text in evidences[:num_edges]:
        input_ids, input_mask = text_tokenize(evi_text, vocab, max_seq_length)
        evidence_tokens.append(input_ids)
        evidence_masks.append(input_mask)
    num_used = len(evidence_tokens)

    edge_features = torch.LongTensor(1, num_edges, max_seq_length).zero_()
    edge_feature_masks = torch.LongTensor(1, num_edges, max_seq_length).zero_()
    edge_sent_mask = torch.ByteTensor(1, num_edges).fill_(1)
    edge_features[0, :num_used, :].copy_(torch.LongTensor(evidence_tokens))
    edge_feature_masks[0, :num_used, :].copy_(torch.LongTensor(evidence_masks))
    edge_sent_mask[0, :num_used].fill_(0)
    return edge_features, edge_feature_masks, edge_sent_mask


def vectorize_qanta(ex, tokenizer, device, istrain, max_seq_length=64):
    """Build the BERT-encoded question/candidate graph for one example.

    Nodes are, in order: every non-``None`` entity found in ``ex['q_et']``
    (question nodes, label -1), the positive entity (label 1) and the
    negative entities (label 0). Question nodes carry the encoding of the
    sub-question they came from; candidate nodes carry the encoding of the
    full question text. An edge goes from a question node to a candidate
    when the candidate's evidence dict has a key whose normalised form is
    that question node; its data holds up to three encoded evidence
    sentences.

    Parameters
    ----------
    ex : dict
        Example with keys ``'text'``, ``'pos_et'``, ``'neg_ets'`` and
        ``'q_et'``.
    tokenizer
        Forwarded to ``text_tokenize``.
    device
        Device used for the BERT forward passes; all stored features are
        moved back to CPU.
    istrain
        Accepted for interface compatibility; not used.
    max_seq_length : int
        Forwarded to ``text_tokenize`` and used as the evidence token width.

    Returns
    -------
    DGLGraph
        Graph with node data ``'first_sent'``, ``'first_sent_mask'``,
        ``'question'``, ``'question_mask'``, ``'label'`` and, when at least
        one edge exists, edge data ``'evidence'``, ``'evidence_mask'`` and
        ``'evidence_sent_mask'``. Masks other than ``'evidence_sent_mask'``
        are computed as ``input_mask.eq(0)``.
    """
    bert_model.eval()
    text = ex['text']
    positive_entity = ex['pos_et']
    negative_entities = ex['neg_ets']
    ## In the QANTA setting at most three evidence sentences are kept per
    ## edge (for efficient training and evaluation).
    num_edges = 3

    g = DGLGraph()
    question_node_list = []
    candidate_node_list = []
    first_sent_tokens = []
    first_sent_masks = []
    question_tokens = []
    question_masks = []
    # node_sub_questions[i] is the index in question_tokens used by node i.
    node_sub_questions = []

    # Index 0 of the question lists is always the full question text.
    input_ids, input_mask = text_tokenize(text, tokenizer, max_seq_length)
    question_tokens.append(input_ids)
    question_masks.append(input_mask)

    for sup_q in ex['q_et']:
        input_ids, input_mask = text_tokenize(sup_q['text'], tokenizer,
                                              max_seq_length)
        question_tokens.append(input_ids)
        question_masks.append(input_mask)
        for et in sup_q['entity']:
            topic = et['et']
            if topic is None:
                continue
            question_node_list.append(topic)
            node_sub_questions.append(len(question_tokens) - 1)
            input_ids, input_mask = text_tokenize(et['first_sent'], tokenizer,
                                                  max_seq_length)
            first_sent_tokens.append(input_ids)
            first_sent_masks.append(input_mask)

    # Candidates: the positive entity first, then the negatives.
    for candidate in [positive_entity] + list(negative_entities):
        candidate_node_list.append(normalize(candidate['et']))
        input_ids, input_mask = text_tokenize(candidate['first_sent'],
                                              tokenizer, max_seq_length)
        first_sent_tokens.append(input_ids)
        first_sent_masks.append(input_mask)
        node_sub_questions.append(0)

    num_question_nodes = len(question_node_list)
    num_nodes = num_question_nodes + len(candidate_node_list)
    g.add_nodes(num_nodes)
    num_questions = len(question_tokens)

    ### Encode questions and first sentences together.
    all_tensor = torch.LongTensor(question_tokens + first_sent_tokens).to(device)
    all_masks_tensor = torch.LongTensor(question_masks +
                                        first_sent_masks).to(device)
    all_encodings = _encode_with_bert(all_tensor, all_masks_tensor)
    all_masks_tensor = all_masks_tensor.cpu()

    g.ndata['first_sent'] = all_encodings[num_questions:].cpu()
    g.ndata['first_sent_mask'] = all_masks_tensor[num_questions:].cpu().eq(0)

    for i in range(num_question_nodes):
        sub_q_num = node_sub_questions[i]
        g.nodes[i].data['question'] = all_encodings[sub_q_num].unsqueeze(0)
        g.nodes[i].data['question_mask'] = all_masks_tensor[
            sub_q_num].unsqueeze(0).eq(0)
        g.nodes[i].data['label'] = torch.tensor(-1).unsqueeze(0)

    # The positive candidate is the first node after the question nodes.
    pos_id = num_question_nodes
    g.nodes[pos_id].data['question'] = all_encodings[0].unsqueeze(0)
    g.nodes[pos_id].data['question_mask'] = all_masks_tensor[0].unsqueeze(
        0).eq(0)
    g.nodes[pos_id].data['label'] = torch.tensor(1).unsqueeze(0)

    #### Negative candidates only use the full question sentence.
    for i in range(num_question_nodes + 1, num_nodes):
        g.nodes[i].data['question'] = all_encodings[0].unsqueeze(0)
        g.nodes[i].data['question_mask'] = all_masks_tensor[0].unsqueeze(0).eq(
            0)
        g.nodes[i].data['label'] = torch.tensor(0).unsqueeze(0)

    #### Edges into the positive candidate.
    for k_entity in positive_entity['evidence']:
        normalized_k_entity = normalize(k_entity)
        if normalized_k_entity in question_node_list:
            s_id = question_node_list.index(normalized_k_entity)
            g.add_edge(s_id, pos_id)
            edge_features, edge_feature_masks, edge_sent_mask = \
                _evidence_edge_tensors(positive_entity['evidence'][k_entity],
                                       tokenizer, num_edges, max_seq_length)
            g.edges[s_id, pos_id].data['evidence'] = edge_features
            g.edges[s_id, pos_id].data['evidence_mask'] = edge_feature_masks
            g.edges[s_id, pos_id].data['evidence_sent_mask'] = edge_sent_mask

    #### Edges into the negative candidates.
    for neg_et in negative_entities:
        for k_entity in neg_et['evidence']:
            normalized_k_entity = normalize(k_entity)
            if normalized_k_entity in question_node_list:
                s_id = question_node_list.index(normalized_k_entity)
                t_id = num_question_nodes + candidate_node_list.index(
                    normalize(neg_et['et']))
                g.add_edge(s_id, t_id)
                # Look the evidence up by the dict's own key (as the positive
                # branch does); the normalised name is only for matching
                # question nodes and may not be a key of the dict.
                edge_features, edge_feature_masks, edge_sent_mask = \
                    _evidence_edge_tensors(neg_et['evidence'][k_entity],
                                           tokenizer, num_edges,
                                           max_seq_length)
                g.edges[s_id, t_id].data['evidence'] = edge_features
                g.edges[s_id, t_id].data['evidence_mask'] = edge_feature_masks
                g.edges[s_id, t_id].data['evidence_sent_mask'] = edge_sent_mask

    ### Replace the evidence token ids by their BERT encodings.
    if 'evidence' in g.edata:
        evi = g.edata['evidence'].to(device)
        evi_mask = g.edata['evidence_mask'].to(device)
        batch_size, sent_max_len, word_max_len = evi.size(0), evi.size(
            1), evi.size(2)
        evi = evi.view(batch_size * sent_max_len, word_max_len)
        evi_mask = evi_mask.view(batch_size * sent_max_len, word_max_len)
        g.edata['evidence'] = _encode_with_bert(evi, evi_mask).view(
            batch_size, sent_max_len, word_max_len, -1)
        g.edata['evidence_mask'] = g.edata['evidence_mask'].eq(0)
    return g
def test_send_multigraph():
    """Exercise ``send``/``recv`` on a multigraph with three parallel edges."""

    def _message_a(edges):
        return {'a': edges.data['a']}

    def _message_b(edges):
        return {'a': edges.data['a'] * 3}

    def _reduce(nodes):
        return {'a': F.max(nodes.mailbox['a'], 1)}

    def answer(*args):
        return F.max(F.stack(args, 0), 0)

    # Edges 0, 1, 2 all go 0 -> 1; edge 3 goes 2 -> 1.
    graph = DGLGraph(multigraph=True)
    graph.add_nodes(3)
    for _ in range(3):
        graph.add_edge(0, 1)
    graph.add_edge(2, 1)

    edge_vals = F.randn((4, 5))

    def _reset():
        # Every scenario starts from zero node data and the same edge data.
        graph.ndata['a'] = F.zeros((3, 5))
        graph.edata['a'] = edge_vals

    # Send along explicit edge ids.
    _reset()
    graph.send([0, 2], message_func=_message_a)
    graph.recv(1, _reduce)
    result = graph.ndata['a']
    assert F.allclose(result[1], answer(edge_vals[0], edge_vals[2]))

    _reset()
    graph.send([0, 2, 3], message_func=_message_a)
    graph.recv(1, _reduce)
    result = graph.ndata['a']
    assert F.allclose(result[1],
                      answer(edge_vals[0], edge_vals[2], edge_vals[3]))

    # Send along (src, dst) pairs, which covers all parallel edges.
    _reset()
    graph.send(([0, 2], [1, 1]), _message_a)
    graph.recv(1, _reduce)
    result = graph.ndata['a']
    assert F.allclose(result[1], F.max(edge_vals, 0))

    # A send by node pair followed by a send by edge ids.
    _reset()
    graph.send((2, 1), _message_a)
    graph.send([0, 1], message_func=_message_b)
    graph.recv(1, _reduce)
    result = graph.ndata['a']
    assert F.allclose(result[1],
                      answer(edge_vals[0] * 3, edge_vals[1] * 3, edge_vals[3]))

    # Two consecutive sends by edge id.
    _reset()
    graph.send(0, message_func=_message_a)
    graph.send(1, message_func=_message_b)
    graph.recv(1, _reduce)
    result = graph.ndata['a']
    assert F.allclose(result[1], answer(edge_vals[0], edge_vals[1] * 3))

    # Combined send_and_recv by edge ids; nodes 0 and 2 must stay zero.
    _reset()
    graph.send_and_recv([0, 2, 3],
                        message_func=_message_a,
                        reduce_func=_reduce)
    result = graph.ndata['a']
    assert F.allclose(result[1],
                      answer(edge_vals[0], edge_vals[2], edge_vals[3]))
    assert F.allclose(result[[0, 2]], F.zeros((2, 5)))
def get_batch_dist(filename, batch_size, gpu_rank):
    """Yield batched graphs for one shard of a JSON-lines impression file.

    The file is cut into four equal contiguous shards and only shard
    ``gpu_rank`` is processed. Every line becomes one DGLGraph; the context
    of its last node is replaced by the features of one randomly drawn entry
    of the line's ``'neg_list'``. Graphs are grouped into batches of exactly
    ``batch_size``; a trailing incomplete batch is not yielded.

    Each graph is built once and a batch is emitted as soon as it is full.
    (Previously a batch was only emitted when the *next* example was
    processed, so the graph built at a batch boundary was wasted and a final
    complete batch was never emitted.)

    Parameters
    ----------
    filename : str
        Path of the JSON-lines file; each record needs ``'node'``,
        ``'neg_list'`` and ``'imp_id'``.
    batch_size : int
        Number of graphs per yielded batch.
    gpu_rank : int
        Index of the shard to read.

    Yields
    ------
    tuple
        ``(batched_graph, imp_ids, zeros)`` where ``imp_ids`` lists the
        ``'imp_id'`` of every graph in the batch and ``zeros`` is a tensor of
        zeros with one entry per graph.
    """
    doc_dict = read_bert_features(
        '../../Recommenders/data/MINDlarge_train/news_token_bert_features.txt')

    # Read through a context manager so the file handle is closed.
    with open(filename, 'r') as handle:
        lines = handle.readlines()

    # NOTE(review): the shard count is hard-coded to 4 — confirm it matches
    # the number of workers in use.
    shard_len = int(len(lines) / 4)
    lines = lines[gpu_rank * shard_len:(gpu_rank + 1) * shard_len]

    bert_max_len = 20
    batch_graph = []
    imp_index = []
    for line in lines:
        data = json.loads(line)
        nodes = data['node']
        num_nodes = len(nodes)

        neg_list = data['neg_list']
        neg_sample = np.random.choice(len(neg_list), 1, replace=False)[0]
        nodes[-1]['context'] = doc_dict[neg_list[neg_sample]]

        g = DGLGraph()
        g.add_nodes(num_nodes)
        # Connect every node except the last two with all later nodes, in
        # both directions; the last two nodes are not linked to each other.
        for k in range(num_nodes - 2):
            for j in range(k + 1, num_nodes):
                g.add_edge(k, j)
                g.add_edge(j, k)

        for k, node in enumerate(nodes):
            context = [int(x) for x in node['context']]
            encoding_inputs, encoding_masks, encoding_ids = encode_sequence_hotpot(
                context, bert_max_len)
            g.nodes[k].data['encoding'] = encoding_inputs.unsqueeze(0)
            g.nodes[k].data['encoding_mask'] = encoding_masks.unsqueeze(0)
            g.nodes[k].data['segment_id'] = encoding_ids.unsqueeze(0)

            # Labels 1 and 0 are kept; any other label becomes -2 on the
            # third-to-last node and -1 elsewhere.
            if node['label'] == 1:
                label_value = 1
            elif node['label'] == 0:
                label_value = 0
            elif k == num_nodes - 3:
                label_value = -2
            else:
                label_value = -1
            g.nodes[k].data['label'] = torch.tensor(label_value).unsqueeze(
                0).type(torch.FloatTensor)

        batch_graph.append(g)
        imp_index.append(data['imp_id'])

        if len(batch_graph) >= batch_size:
            yield dgl.batch(batch_graph), imp_index, torch.tensor(
                [0] * len(imp_index))
            # Rebind (do not clear) so the yielded list stays intact.
            batch_graph = []
            imp_index = []
def batch_transform_bert_hotpot(inst, bert_max_len, bert_tokenizer):
    """Turn one question instance into a DGLGraph with encoded nodes.

    One edge is added per distinct ``(start, end)`` pair found in
    ``inst['edge']``. The distinct sentences of such a pair are collected as
    evidence on the ``end`` node (``node['evidence']`` is overwritten in
    ``inst``). Every node then receives its encoded sequence, an answer
    label, a span label and the start/end span positions shifted by the
    ``B_start`` value returned by ``encode_sequence_hotpot``.

    :param inst: dict with ``'node'``, ``'edge'``, ``'question'``, ``'qid'``
    :param bert_max_len: forwarded to ``encode_sequence_hotpot``
    :param bert_tokenizer: forwarded to ``encode_sequence_hotpot``
    :return: ``(graph, inst['qid'])``
    """
    nodes = inst['node']
    question = inst['question']

    graph = DGLGraph()
    graph.add_nodes(len(nodes))

    for node in nodes:
        node['evidence'] = list()

    #### Group the distinct sentences by (start, end) pair, keeping order.
    sentences_by_pair = dict()
    for edge in inst['edge']:
        pair = (edge['start'], edge['end'])
        collected = sentences_by_pair.setdefault(pair, list())
        if edge['sent'] not in collected:
            collected.append(edge['sent'])

    for (start, end), sentences in sentences_by_pair.items():
        graph.add_edge(start, end)
        for sent in sentences:
            nodes[end]['evidence'].append(sent)

    def _float_scalar(value):
        # One-element float tensor holding ``value``.
        return torch.tensor(value).unsqueeze(0).type(torch.FloatTensor)

    for i, node in enumerate(nodes):
        context = node['context']
        evidence = list(set(node['evidence']))
        encoding_inputs, encoding_masks, encoding_ids, B_start = encode_sequence_hotpot(
            question, context, evidence, bert_max_len, bert_tokenizer)

        node_data = graph.nodes[i].data
        node_data['encoding'] = encoding_inputs.unsqueeze(0)
        node_data['encoding_mask'] = encoding_masks.unsqueeze(0)
        node_data['segment_id'] = encoding_ids.unsqueeze(0)
        node_data['B_start'] = B_start

        is_ans = node['is_ans']
        if is_ans == 1:
            node_data['label'] = _float_scalar(1)
        elif is_ans == 0:
            node_data['label'] = _float_scalar(0)
        else:
            node_data['label'] = _float_scalar(-1)

        if is_ans == 1:
            node_data['span_label'] = _float_scalar(0)
        else:
            node_data['span_label'] = _float_scalar(-1)

        spans = node['spans']
        if len(spans) == 0:
            # Placeholder span for nodes without any span.
            spans = [(-1, -1)]
        start_spans = torch.LongTensor([p[0] for p in spans]) + B_start
        end_spans = torch.LongTensor([p[1] for p in spans]) + B_start
        node_data['label_start'] = start_spans
        node_data['label_end'] = end_spans

    return graph, inst['qid']
def _evidence_edge_tensors(evidences, vocab, num_edges, max_seq_length):
    """Tokenize at most ``num_edges`` evidence sentences into padded tensors.

    ``vocab`` is forwarded as the second argument of ``text_tokenize``.
    Returns ``(token_ids, token_masks, sent_mask)``: the first two are long
    tensors of shape ``(1, num_edges, max_seq_length)`` padded with zeros,
    the last is a byte tensor of shape ``(1, num_edges)`` holding 0 for used
    sentence slots and 1 for unused ones.
    """
    evidence_tokens = []
    evidence_masks = []
    for evi_text in evidences[:num_edges]:
        input_ids, input_mask = text_tokenize(evi_text, vocab, max_seq_length)
        evidence_tokens.append(input_ids)
        evidence_masks.append(input_mask)
    num_used = len(evidence_tokens)

    edge_features = torch.LongTensor(1, num_edges, max_seq_length).zero_()
    edge_feature_masks = torch.LongTensor(1, num_edges, max_seq_length).zero_()
    edge_sent_mask = torch.ByteTensor(1, num_edges).fill_(1)
    edge_features[0, :num_used, :].copy_(torch.LongTensor(evidence_tokens))
    edge_feature_masks[0, :num_used, :].copy_(torch.LongTensor(evidence_masks))
    edge_sent_mask[0, :num_used].fill_(0)
    return edge_features, edge_feature_masks, edge_sent_mask


def vectorize_qanta(ex, model, istrain, max_seq_length=128):
    """Build the token-id question/candidate graph for one example.

    Nodes are, in order: every non-``None`` entity found in ``ex['q_et']``
    (question nodes, label -1), the positive entity (label 1) and the
    negative entities (label 0). Question nodes carry the token ids of the
    sub-question they came from; candidate nodes carry the token ids of the
    full question text. An edge goes from a question node to a candidate
    when the candidate's evidence dict has a key whose normalised form is
    that question node; its data holds up to three tokenized evidence
    sentences.

    Parameters
    ----------
    ex : dict
        Example with keys ``'text'``, ``'pos_et'``, ``'neg_ets'`` and
        ``'q_et'``.
    model
        Object whose ``word_dict`` attribute is forwarded to
        ``text_tokenize``.
    istrain
        Accepted for interface compatibility; not used.
    max_seq_length : int
        Forwarded to ``text_tokenize`` and used as the evidence token width.

    Returns
    -------
    DGLGraph
        Graph with node data ``'first_sent'``, ``'first_sent_mask'``,
        ``'question'``, ``'question_mask'``, ``'label'`` and, when at least
        one edge exists, edge data ``'evidence'``, ``'evidence_mask'`` and
        ``'evidence_sent_mask'``. Masks other than ``'evidence_sent_mask'``
        are computed as ``input_mask.eq(0)``.
    """
    text = ex['text']
    positive_entity = ex['pos_et']
    negative_entities = ex['neg_ets']
    ### At most three evidence sentences are kept per edge.
    num_edges = 3

    g = DGLGraph()
    question_node_list = []
    candidate_node_list = []
    first_sent_tokens = []
    first_sent_masks = []
    question_tokens = []
    question_masks = []
    # node_sub_questions[i] is the index in question_tokens used by node i.
    node_sub_questions = []

    # Index 0 of the question lists is always the full question text.
    input_ids, input_mask = text_tokenize(text, model.word_dict,
                                          max_seq_length)
    question_tokens.append(input_ids)
    question_masks.append(input_mask)

    for sup_q in ex['q_et']:
        # Sub-questions are passed through word_tokenize before text_tokenize.
        input_ids, input_mask = text_tokenize(word_tokenize(sup_q['text']),
                                              model.word_dict, max_seq_length)
        question_tokens.append(input_ids)
        question_masks.append(input_mask)
        for et in sup_q['entity']:
            topic = et['et']
            if topic is None:
                continue
            question_node_list.append(topic)
            node_sub_questions.append(len(question_tokens) - 1)
            input_ids, input_mask = text_tokenize(et['first_sent'],
                                                  model.word_dict,
                                                  max_seq_length)
            first_sent_tokens.append(input_ids)
            first_sent_masks.append(input_mask)

    # Candidates: the positive entity first, then the negatives.
    for candidate in [positive_entity] + list(negative_entities):
        candidate_node_list.append(normalize(candidate['et']))
        input_ids, input_mask = text_tokenize(candidate['first_sent'],
                                              model.word_dict, max_seq_length)
        first_sent_tokens.append(input_ids)
        first_sent_masks.append(input_mask)
        node_sub_questions.append(0)

    num_question_nodes = len(question_node_list)
    num_nodes = num_question_nodes + len(candidate_node_list)
    g.add_nodes(num_nodes)
    num_questions = len(question_tokens)

    ### Questions and first sentences share one tensor; questions come first.
    all_tensor = torch.LongTensor(question_tokens + first_sent_tokens)
    all_masks_tensor = torch.LongTensor(question_masks + first_sent_masks)

    #### Node features.
    g.ndata['first_sent'] = all_tensor[num_questions:].cpu()
    g.ndata['first_sent_mask'] = all_masks_tensor[num_questions:].eq(0)

    for i in range(num_question_nodes):
        sub_q_num = node_sub_questions[i]
        g.nodes[i].data['question'] = all_tensor[sub_q_num].unsqueeze(0)
        g.nodes[i].data['question_mask'] = all_masks_tensor[
            sub_q_num].unsqueeze(0).eq(0)
        g.nodes[i].data['label'] = torch.tensor(-1).unsqueeze(0)

    # The positive candidate is the first node after the question nodes.
    pos_id = num_question_nodes
    g.nodes[pos_id].data['question'] = all_tensor[0].unsqueeze(0)
    g.nodes[pos_id].data['question_mask'] = all_masks_tensor[0].unsqueeze(
        0).eq(0)
    g.nodes[pos_id].data['label'] = torch.tensor(1).unsqueeze(0)

    #### Negative candidates only use the full question sentence.
    for i in range(num_question_nodes + 1, num_nodes):
        g.nodes[i].data['question'] = all_tensor[0].unsqueeze(0)
        g.nodes[i].data['question_mask'] = all_masks_tensor[0].unsqueeze(0).eq(
            0)
        g.nodes[i].data['label'] = torch.tensor(0).unsqueeze(0)

    #### Edges into the positive candidate.
    for k_entity in positive_entity['evidence']:
        normalized_k_entity = normalize(k_entity)
        if normalized_k_entity in question_node_list:
            s_id = question_node_list.index(normalized_k_entity)
            g.add_edge(s_id, pos_id)
            edge_features, edge_feature_masks, edge_sent_mask = \
                _evidence_edge_tensors(positive_entity['evidence'][k_entity],
                                       model.word_dict, num_edges,
                                       max_seq_length)
            g.edges[s_id, pos_id].data['evidence'] = edge_features
            g.edges[s_id, pos_id].data['evidence_mask'] = edge_feature_masks.eq(0)
            g.edges[s_id, pos_id].data['evidence_sent_mask'] = edge_sent_mask

    #### Edges into the negative candidates.
    for neg_et in negative_entities:
        for k_entity in neg_et['evidence']:
            normalized_k_entity = normalize(k_entity)
            if normalized_k_entity in question_node_list:
                s_id = question_node_list.index(normalized_k_entity)
                t_id = num_question_nodes + candidate_node_list.index(
                    normalize(neg_et['et']))
                g.add_edge(s_id, t_id)
                # Look the evidence up by the dict's own key (as the positive
                # branch does); the normalised name is only for matching
                # question nodes and may not be a key of the dict.
                edge_features, edge_feature_masks, edge_sent_mask = \
                    _evidence_edge_tensors(neg_et['evidence'][k_entity],
                                           model.word_dict, num_edges,
                                           max_seq_length)
                g.edges[s_id, t_id].data['evidence'] = edge_features
                g.edges[s_id, t_id].data['evidence_mask'] = edge_feature_masks.eq(0)
                g.edges[s_id, t_id].data['evidence_sent_mask'] = edge_sent_mask
    return g
class MoleculeEnv(object):
    """Environment that builds a molecule step by step for DGMG.

    The environment keeps a ``DGLGraph`` (and optionally an editable RDKit
    molecule) in sync while atoms and bonds are added one at a time.

    Parameters
    ----------
    atom_types : list
        Allowed atom symbols, e.g. ['C', 'N'].
    bond_types : list
        Allowed bond types, e.g.
        [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
         Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
    """

    def __init__(self, atom_types, bond_types):
        super(MoleculeEnv, self).__init__()

        self.atom_types = atom_types
        self.bond_types = bond_types

        # Reverse lookups: type -> position in the corresponding list.
        self.atom_type_to_id = {
            a_type: index for index, a_type in enumerate(atom_types)
        }
        self.bond_type_to_id = {
            b_type: index for index, b_type in enumerate(bond_types)
        }

    def get_decision_sequence(self, mol, atom_order):
        """Translate a molecule into the DGMG decisions that rebuild it.

        Parameters
        ----------
        mol : Chem.rdchem.Mol
            Molecule to decompose.
        atom_order : list
            Order in which the original atoms are generated: the atom with
            original index ``atom_order[i]`` becomes atom ``i``.

        Returns
        -------
        decisions : list
            List of 2-tuples ``(kind, value)``:

            - ``kind == 0``: add an atom of type ``self.atom_types[value]``,
              or stop adding atoms when ``value == len(self.atom_types)``.
            - ``kind == 1``: add a bond of type ``self.bond_types[value]``,
              or stop adding bonds when ``value == len(self.bond_types)``.
            - ``kind == 2``: ``value`` is the (new) id of the already
              generated atom the bond connects to.
        """
        stop_bond = (1, len(self.bond_types))
        stop_atom = (0, len(self.atom_types))

        sequence = []
        relabel = {}  # original atom index -> new atom index
        for new_idx, old_idx in enumerate(atom_order):
            atom = mol.GetAtomWithIdx(old_idx)
            sequence.append((0, self.atom_type_to_id[atom.GetSymbol()]))

            for bond in atom.GetBonds():
                begin = bond.GetBeginAtomIdx()
                end = bond.GetEndAtomIdx()
                # The neighbour is whichever endpoint is not the current atom.
                neighbour = begin if end == old_idx else end
                # Only bonds towards atoms generated earlier can be emitted.
                if neighbour in relabel:
                    sequence.append(
                        (1, self.bond_type_to_id[bond.GetBondType()]))
                    sequence.append((2, relabel[neighbour]))

            sequence.append(stop_bond)
            relabel[old_idx] = new_idx

        sequence.append(stop_atom)
        return sequence

    def reset(self, rdkit_mol=False):
        """Start over with an empty molecule.

        Parameters
        ----------
        rdkit_mol : bool
            If True, also maintain an editable ``Chem.RWMol`` mirroring the
            graph so the molecule being generated can be inspected.
        """
        graph = DGLGraph()
        # Features of nodes/edges added later default to zero tensors.
        graph.set_n_initializer(dgl.frame.zero_initializer)
        graph.set_e_initializer(dgl.frame.zero_initializer)
        self.dgl_graph = graph

        # RWMol is the editable molecule class of RDKit.
        self.mol = Chem.RWMol(Chem.MolFromSmiles('')) if rdkit_mol else None

    def num_atoms(self):
        """Return the number of atoms generated so far.

        Returns
        -------
        int
        """
        return self.dgl_graph.number_of_nodes()

    def add_atom(self, type):
        """Append one atom of the given type.

        Parameters
        ----------
        type : int
            Index into ``self.atom_types``.
        """
        self.dgl_graph.add_nodes(1)
        if self.mol is None:
            return
        self.mol.AddAtom(Chem.Atom(self.atom_types[type]))

    def add_bond(self, u, v, type, bi_direction=True):
        """Connect atoms ``u`` and ``v`` with a bond of the given type.

        Parameters
        ----------
        u : int
            Index of the first atom.
        v : int
            Index of the second atom.
        type : int
            Index into ``self.bond_types``.
        bi_direction : bool
            If True, add both (u, v) and (v, u) to the DGLGraph; otherwise
            only (u, v) is added.
        """
        if not bi_direction:
            self.dgl_graph.add_edge(u, v)
        else:
            self.dgl_graph.add_edges([u, v], [v, u])

        if self.mol is not None:
            self.mol.AddBond(u, v, self.bond_types[type])

    def get_current_smiles(self):
        """Return the SMILES string of the molecule generated so far.

        Requires that ``reset`` was called with ``rdkit_mol=True``.

        Returns
        -------
        s : str
            SMILES
        """
        assert self.mol is not None, 'Expect a Chem.rdchem.Mol object initialized.'
        return Chem.MolToSmiles(self.mol)
def _bert_encode_in_batches(token_ids, attention_masks, batch_size=50):
    """Run the module-level ``bert_model`` over the rows in fixed-size chunks.

    Returns the concatenation (along dim 0) of the first output of
    ``bert_model`` for every chunk, detached and moved to the CPU.
    """
    encodings = []
    # The outputs are detached right away, so no autograd graph is needed.
    with torch.no_grad():
        for start in range(0, token_ids.size(0), batch_size):
            stop = start + batch_size
            encoding, _ = bert_model(token_ids[start:stop], None,
                                     attention_masks[start:stop])
            encodings.append(encoding.detach().cpu())
    return torch.cat(encodings, dim=0)


def _set_evidence_edge_features(g, s_id, t_id, evidence_texts, tokenizer,
                                num_edges, max_seq_length):
    """Tokenize up to ``num_edges`` evidence texts and store them on edge (s_id, t_id)."""
    evidence_tokens = []
    evidence_masks = []
    for evi_text in evidence_texts[:num_edges]:
        input_ids, input_mask = text_tokenize(evi_text, tokenizer,
                                              max_seq_length)
        evidence_tokens.append(input_ids)
        evidence_masks.append(input_mask)

    # Fixed-size, zero-padded containers: one row per evidence sentence slot.
    edge_features = torch.zeros(1, num_edges, max_seq_length,
                                dtype=torch.long)
    edge_feature_masks = torch.zeros(1, num_edges, max_seq_length,
                                     dtype=torch.long)
    # 1 marks an unused sentence slot, 0 a slot holding a real sentence.
    edge_sent_mask = torch.ones(1, num_edges, dtype=torch.uint8)

    # Guard: an empty evidence list would yield a 1-D empty tensor that
    # cannot be copied into the (0, max_seq_length) slice.
    if evidence_tokens:
        num_sents = len(evidence_tokens)
        edge_features[0, :num_sents, :].copy_(
            torch.LongTensor(evidence_tokens))
        edge_feature_masks[0, :num_sents, :].copy_(
            torch.LongTensor(evidence_masks))
        edge_sent_mask[0, :num_sents].fill_(0)

    g.edges[s_id, t_id].data['evidence'] = edge_features
    g.edges[s_id, t_id].data['evidence_mask'] = edge_feature_masks
    g.edges[s_id, t_id].data['evidence_sent_mask'] = edge_sent_mask


def vectorize_trivia(ex, tokenizer, device, istrain, max_seq_length=64):
    """Build a DGLGraph with BERT-encoded features for one TriviaQA example.

    Nodes are laid out as: question-entity nodes first, then the positive
    candidate, then the negative candidates. Edges go from a question-entity
    node to a candidate node when the candidate lists evidence for that
    entity.

    Parameters
    ----------
    ex : dict
        Example with the keys read below: ``'text'`` (full question),
        ``'q_et'`` (list of sub-questions, each with ``'text'`` and
        ``'entity'``, a list of dicts with ``'et'`` and ``'first_sent'``),
        ``'pos_et'`` and ``'neg_ets'`` (candidate dicts with ``'et'``,
        ``'first_sent'`` and ``'evidence'``, a mapping from entity to a list
        of evidence texts).
    tokenizer
        Passed through to ``text_tokenize``.
    device
        Device on which the module-level ``bert_model`` is evaluated.
    istrain
        Unused; kept for interface compatibility.
    max_seq_length : int
        Passed through to ``text_tokenize``; also the width of the evidence
        token tensors.

    Returns
    -------
    DGLGraph
        Node data: ``'first_sent'``, ``'first_sent_mask'``, ``'question'``,
        ``'question_mask'``, ``'label'`` (-1 question entity, 1 positive
        candidate, 0 negative candidate). Edge data (only if at least one
        edge exists): ``'evidence'``, ``'evidence_mask'``,
        ``'evidence_sent_mask'``. The ``*_mask`` tensors derived from token
        masks are stored as ``mask.eq(0)``.
    """
    bert_model.eval()
    text = ex['text']
    positive_entity = ex['pos_et']
    negative_entities = ex['neg_ets']
    num_edges = 5  # maximum number of evidence sentences kept per edge

    g = DGLGraph()
    question_node_list = []
    candidate_node_list = []
    first_sent_tokens = []
    first_sent_masks = []
    question_tokens = []
    question_masks = []
    # node_sub_questions[i] is the row of question_tokens used by node i.
    node_sub_questions = []

    # Row 0 of question_tokens is always the full question.
    input_ids, input_mask = text_tokenize(text, tokenizer, max_seq_length)
    question_tokens.append(input_ids)
    question_masks.append(input_mask)

    ### Since Trivia QA question has much fewer entities, we also incorporate
    ### IR retrieved pages as additional entities linked to candidate nodes
    ### (So that we get some more edge evidence sentences)
    for sup_q in ex['q_et']:
        input_ids, input_mask = text_tokenize(sup_q['text'], tokenizer,
                                              max_seq_length)
        question_tokens.append(input_ids)
        question_masks.append(input_mask)
        question_idx = len(question_tokens) - 1
        for et in sup_q['entity']:
            topic = et['et']
            node_first_sent = et['first_sent']
            if topic is None:
                continue
            # NOTE(review): topics are stored as-is while the edge lookups
            # below compare against normalize(...); presumably topics are
            # already normalized -- confirm against the data pipeline.
            question_node_list.append(topic)
            node_sub_questions.append(question_idx)
            input_ids, input_mask = text_tokenize(node_first_sent, tokenizer,
                                                  max_seq_length)
            first_sent_tokens.append(input_ids)
            first_sent_masks.append(input_mask)

    # Candidates: the positive one first, then the negatives. All of them
    # use the full question (row 0).
    for candidate in [positive_entity] + list(negative_entities):
        candidate_node_list.append(normalize(candidate['et']))
        input_ids, input_mask = text_tokenize(candidate['first_sent'],
                                              tokenizer, max_seq_length)
        first_sent_tokens.append(input_ids)
        first_sent_masks.append(input_mask)
        node_sub_questions.append(0)

    num_question_nodes = len(question_node_list)
    positive_id = num_question_nodes  # node id of the positive candidate
    num_nodes = num_question_nodes + len(candidate_node_list)
    g.add_nodes(num_nodes)
    num_questions = len(question_tokens)

    ### combine question and first sentence, then encode everything at once
    all_tensor = torch.LongTensor(question_tokens +
                                  first_sent_tokens).to(device)
    all_masks_tensor = torch.LongTensor(question_masks +
                                        first_sent_masks).to(device)
    all_encodings = _bert_encode_in_batches(all_tensor, all_masks_tensor)
    all_masks_tensor = all_masks_tensor.cpu()

    # Rows after the questions are the per-node first sentences.
    g.ndata['first_sent'] = all_encodings[num_questions:]
    g.ndata['first_sent_mask'] = all_masks_tensor[num_questions:].eq(0)

    for i in range(num_question_nodes):
        sub_q_num = node_sub_questions[i]
        g.nodes[i].data['question'] = all_encodings[sub_q_num].unsqueeze(0)
        g.nodes[i].data['question_mask'] = all_masks_tensor[
            sub_q_num].unsqueeze(0).eq(0)
        g.nodes[i].data['label'] = torch.tensor(-1).unsqueeze(0)

    g.nodes[positive_id].data['question'] = all_encodings[0].unsqueeze(0)
    g.nodes[positive_id].data['question_mask'] = all_masks_tensor[
        0].unsqueeze(0).eq(0)
    g.nodes[positive_id].data['label'] = torch.tensor(1).unsqueeze(0)

    #### for candidates, we only use the full question sentence
    for i in range(positive_id + 1, num_nodes):
        g.nodes[i].data['question'] = all_encodings[0].unsqueeze(0)
        g.nodes[i].data['question_mask'] = all_masks_tensor[0].unsqueeze(
            0).eq(0)
        g.nodes[i].data['label'] = torch.tensor(0).unsqueeze(0)

    # Edges towards the positive candidate.
    for k_entity in positive_entity['evidence']:
        normalized_k_entity = normalize(k_entity)
        if normalized_k_entity not in question_node_list:
            continue
        s_id = question_node_list.index(normalized_k_entity)
        g.add_edge(s_id, positive_id)
        _set_evidence_edge_features(g, s_id, positive_id,
                                    positive_entity['evidence'][k_entity],
                                    tokenizer, num_edges, max_seq_length)

    # Edges towards the negative candidates.
    for neg_et in negative_entities:
        for k_entity in neg_et['evidence']:
            normalized_k_entity = normalize(k_entity)
            if normalized_k_entity not in question_node_list:
                continue
            s_id = question_node_list.index(normalized_k_entity)
            t_id = num_question_nodes + candidate_node_list.index(
                normalize(neg_et['et']))
            g.add_edge(s_id, t_id)
            # Index with the raw key we are iterating over: the normalized
            # form is not guaranteed to be a key of the evidence dict.
            _set_evidence_edge_features(g, s_id, t_id,
                                        neg_et['evidence'][k_entity],
                                        tokenizer, num_edges, max_seq_length)

    ### Batch the evidence sentences and replace token ids by BERT embeddings
    if 'evidence' in g.edata:
        evi = g.edata['evidence'].to(device)
        evi_mask = g.edata['evidence_mask'].to(device)
        batch_size, sent_max_len, word_max_len = evi.size(0), evi.size(
            1), evi.size(2)
        # Flatten (edge, sentence) so every sentence is one BERT input row.
        evi = evi.view(batch_size * sent_max_len, word_max_len)
        evi_mask = evi_mask.view(batch_size * sent_max_len, word_max_len)
        encodings = _bert_encode_in_batches(evi, evi_mask)
        g.edata['evidence'] = encodings.view(batch_size, sent_max_len,
                                             word_max_len, -1)
        g.edata['evidence_mask'] = g.edata['evidence_mask'].eq(0)

    return g
def mol2dgl(cand_batch, mol_tree_batch):
    """Convert candidate molecules to DGLGraphs and map tree edges onto them.

    Parameters
    ----------
    cand_batch : iterable
        Triples ``(mol, mol_tree, ctr_node_id)``.
    mol_tree_batch
        Object exposing ``edge_list``, used to test whether a pair of tree
        node batch ids is a tree edge.

    Returns
    -------
    tuple
        ``(cand_graphs, tree_mess_source_edges, tree_mess_target_edges,
        tree_mess_target_nodes)``: the list of candidate graphs, the tree
        edges ``(x_bid, y_bid)`` that have a counterpart, the matching
        candidate-graph edges, and the destination nodes of those edges.
        Candidate node indices are offset by the number of nodes of all
        previously processed candidate graphs.
    """
    graphs = []
    source_edges = []  # edges on the trees ...
    target_edges = []  # ... mapped to these edges on candidate graphs
    target_nodes = []
    offset = 0  # nodes contained in the graphs built so far

    for mol, mol_tree, ctr_node_id in cand_batch:
        atom_rows = []
        bond_rows = []
        # Not used below; the lookup is retained so that an invalid
        # ctr_node_id still fails at this point.
        center = mol_tree.nodes[ctr_node_id]
        _center_bid = center['idx']

        g = DGLGraph()
        for atom in mol.GetAtoms():
            atom_rows.append(atom_features(atom))
            g.add_node(atom.GetIdx())

        for bond in mol.GetBonds():
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()
            begin_idx = a1.GetIdx()
            end_idx = a2.GetIdx()
            row = bond_features(bond)

            # Each bond becomes two directed edges sharing the same features.
            g.add_edge(begin_idx, end_idx)
            bond_rows.append(row)
            g.add_edge(end_idx, begin_idx)
            bond_rows.append(row)

            x_nid = a1.GetAtomMapNum()
            y_nid = a2.GetAtomMapNum()
            # Tree node ID in the batch (-1 when the atom has no map number)
            x_bid = mol_tree.nodes[x_nid - 1]['idx'] if x_nid > 0 else -1
            y_bid = mol_tree.nodes[y_nid - 1]['idx'] if y_nid > 0 else -1
            if x_bid < 0 or y_bid < 0 or x_bid == y_bid:
                continue

            directions = (
                (x_bid, y_bid, begin_idx, end_idx),
                (y_bid, x_bid, end_idx, begin_idx),
            )
            for src_bid, dst_bid, src_idx, dst_idx in directions:
                if (src_bid, dst_bid) in mol_tree_batch.edge_list:
                    target_edges.append((src_idx + offset, dst_idx + offset))
                    source_edges.append((src_bid, dst_bid))
                    target_nodes.append(dst_idx + offset)

        offset += len(g.nodes)

        atom_x = torch.stack(atom_rows)
        g.set_n_repr({'x': atom_x})
        if bond_rows:
            g.set_e_repr({
                'x': torch.stack(bond_rows),
                'src_x': atom_x.new(len(bond_rows), ATOM_FDIM).zero_(),
            })
        graphs.append(g)

    return graphs, source_edges, target_edges, target_nodes
class D2GCN(nn.Module):
    """Attention-weighted message passing layer over a fully connected graph.

    For every edge an edge feature is computed from the concatenated source
    and destination node features (``fedge``); the message is computed from
    the source feature concatenated with that edge feature (``fnode``) and
    scaled by a softmax-normalized attention weight. Messages are summed at
    the destination and added to the node's own feature.

    The hyper-parameters ``dropout``, ``feature_drop``, ``att_drop`` and
    ``alpha`` as well as ``edge_softmax`` are read from module scope.

    NOTE(review): ``forward`` multiplies the node features by attention
    vectors of width ``out_feat_dim`` and ``simple_reduce`` adds messages of
    width ``out_feat_dim`` to them, so this presumably requires
    ``in_feat_dim == out_feat_dim`` -- confirm.

    Parameters
    ----------
    in_feat_dim : int
        Width of the input node features.
    out_feat_dim : int
        Width of the edge features and messages.
    """

    def __init__(self, in_feat_dim, out_feat_dim):
        super(D2GCN, self).__init__()
        # Edge network: [src feature, dst feature] -> edge feature.
        self.fedge = nn.Sequential(
            nn.Linear(in_feat_dim * 2, in_feat_dim // 64),
            nn.BatchNorm1d(in_feat_dim // 64),
            nn.Dropout(dropout),
            nn.LeakyReLU(),
            nn.Linear(in_feat_dim // 64, out_feat_dim),
            nn.BatchNorm1d(out_feat_dim),
            nn.Dropout(dropout),
            nn.ReLU())
        # NOTE(review): feat_drop is configured here but not applied
        # anywhere in this class -- confirm whether that is intended.
        if feature_drop:
            self.feat_drop = nn.Dropout(feature_drop)
        else:
            self.feat_drop = lambda x: x
        if att_drop:
            self.att_drop = nn.Dropout(att_drop)
        else:
            self.att_drop = lambda x: x
        # Attention vectors for the source ("l") and destination ("r") side.
        self.attn_l = nn.Parameter(torch.Tensor(size=(1, out_feat_dim)))
        self.attn_r = nn.Parameter(torch.Tensor(size=(1, out_feat_dim)))
        self.relu = nn.LeakyReLU(alpha)
        self.softmax = edge_softmax
        # Node network: [src feature, edge feature] -> message.
        self.fnode = nn.Sequential(
            nn.Linear(in_feat_dim + out_feat_dim, out_feat_dim // 64),
            nn.BatchNorm1d(out_feat_dim // 64),
            nn.Dropout(dropout),
            nn.LeakyReLU(),
            nn.Linear(out_feat_dim // 64, out_feat_dim),
            nn.BatchNorm1d(out_feat_dim),
            nn.Dropout(dropout),
            nn.ReLU())
        nn.init.xavier_normal_(self.attn_l.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_r.data, gain=1.414)

    def build_graph(self, num_nodes, device):
        """Create ``self.g`` with ``num_nodes`` nodes and edges between all distinct pairs.

        Also registers ``send_source`` / ``simple_reduce`` as the graph's
        message and reduce functions.
        """
        self.g = DGLGraph()
        self.g.add_nodes(num_nodes)
        # NOTE(review): the loop visits both (i, j) and (j, i) and adds both
        # directions each time, so every directed edge is inserted twice --
        # confirm whether the duplication is intended.
        for i in range(0, num_nodes):
            for j in range(0, num_nodes):
                if i != j:
                    self.g.add_edge(i, j)
                    self.g.add_edge(j, i)
        # NOTE(review): the return value is discarded, which relies on
        # DGLGraph.to acting in place -- confirm for the DGL version in use.
        self.g.to(device)
        self.g.register_message_func(self.send_source)
        self.g.register_reduce_func(self.simple_reduce)

    def send_source(self, edges):
        """Message function: attention-scaled message for every edge."""
        # Call the modules (not .forward) so registered hooks are honoured.
        edge_feature = self.fedge(
            torch.cat((edges.src["h"], edges.dst["h"]), dim=1))
        msg = self.fnode(torch.cat((edges.src["h"], edge_feature), dim=1))
        m = torch.mul(msg, edges.data['a_drop'])
        return {"m": m}

    def simple_reduce(self, nodes):
        """Reduce function: sum incoming messages and add the node's own feature."""
        return {"h": torch.sum(nodes.mailbox['m'], dim=1) + nodes.data["h"]}

    def edge_attention(self, edges):
        """Edge function: un-normalized attention score from the node scores."""
        a = self.relu(edges.src['a1'] + edges.dst['a2'])
        return {'a': a}

    def edge_softmax(self):
        """Normalize the scores in ``edata['a']`` and store them (after dropout) as ``'a_drop'``."""
        att = self.softmax(self.g, self.g.edata.pop('a'))
        self.g.edata['a_drop'] = self.att_drop(att)

    def forward(self, n_feature):
        """Run one round of message passing on ``self.g``.

        ``build_graph`` must have been called before.

        Parameters
        ----------
        n_feature : torch.Tensor
            Node features, one row per node of ``self.g``.

        Returns
        -------
        torch.Tensor
            Updated node features (popped from the graph's node data).
        """
        # Per-node scalar attention scores for the source / destination role.
        a1 = (n_feature * self.attn_l).sum(dim=-1).unsqueeze(-1)
        a2 = (n_feature * self.attn_r).sum(dim=-1).unsqueeze(-1)
        self.g.ndata.update({'h': n_feature, 'a1': a1, 'a2': a2})
        self.g.apply_edges(self.edge_attention)
        self.edge_softmax()
        self.g.send(self.g.edges())
        self.g.recv(self.g.nodes())
        return self.g.ndata.pop('h')