def get_index_mapper(graph, next_graph):
    """Build the index lists that relate ``graph`` to ``next_graph``.

    When both inputs are batched graphs they are unbatched and processed
    pair by pair; the per-pair index lists are concatenated.  The running
    count of ``NODE_ALLY`` nodes seen so far in the current / next graphs is
    forwarded to ``_get_index_mapper_list`` as its two offset arguments.
    Otherwise the pair is handled directly with both offsets set to 0.

    :param graph: a single graph or a ``dgl.BatchedDGLGraph``.
    :param next_graph: the graph paired with ``graph`` (same kind).
    :return: ``(cur_idx, next_idx)`` as produced by
        ``_get_index_mapper_list`` (concatenated over the batch when both
        inputs are batched).
    """
    # isinstance (rather than an exact type comparison) also accepts
    # subclasses of BatchedDGLGraph.
    both_batched = (isinstance(graph, dgl.BatchedDGLGraph)
                    and isinstance(next_graph, dgl.BatchedDGLGraph))

    if both_batched:
        cur_idx = []
        next_idx = []
        curr_num_nodes = 0
        next_num_nodes = 0
        for g, ng in zip(dgl.unbatch(graph), dgl.unbatch(next_graph)):
            n_ally_cur = len(get_filtered_node_index_by_type(g, NODE_ALLY))
            n_ally_next = len(get_filtered_node_index_by_type(ng, NODE_ALLY))
            ci, ni = _get_index_mapper_list(g, ng, curr_num_nodes,
                                            next_num_nodes)
            cur_idx.extend(ci)
            next_idx.extend(ni)
            # Offsets advance only after the pair has been mapped.
            curr_num_nodes += n_ally_cur
            next_num_nodes += n_ally_next
    else:
        cur_idx, next_idx = _get_index_mapper_list(graph, next_graph, 0, 0)
    return cur_idx, next_idx
def make_step_modifications(action, steps):
    """Apply the selected actions to each sample and re-collate the steps.

    ``action`` is recorded in the module-level ``selected_actions`` list and
    reshaped to one row per sample with ``len(action_list)`` columns.  For
    every sample the entries that round to a non-zero value are applied, in
    order, via ``action_list``; when none do, the arg-max entry is applied
    instead.  The work is done under ``torch.no_grad()``.

    :param action: tensor of action scores for the whole batch.
    :param steps: tuple ``(g1, mp1, g2, mp2)``; ``g1`` and ``g2`` are batched
        graphs, ``mp1`` and ``mp2`` are indexable per sample.
    :return: the result of ``my_collate`` on the rebuilt per-sample steps.
    """
    selected_actions.append(action)
    with torch.no_grad():
        g1, mp1, g2, mp2 = steps
        next_graphs = dgl.unbatch(g2)
        prev_graphs = dgl.unbatch(g1)
        score_rows = torch.reshape(action, (-1, len(action_list)))

        rebuilt = []
        for k, next_mp in enumerate(mp2):
            row = score_rows[k]
            chosen = torch.nonzero(torch.round(row))
            if chosen.size()[0] == 0:
                # if no action is above 0.5, just select the max
                chosen = torch.tensor([torch.argmax(row, 0)])

            state = [next_graphs[k], next_mp]
            for act in chosen:
                state = action_list[act](state)

            new_graph, new_mp = state
            rebuilt.append([[prev_graphs[k], mp1[k], new_graph, new_mp], 0])

        new_steps = my_collate(rebuilt)
    return new_steps
def test_batch_unbatch_frame(idtype):
    """Check node/edge frames of batched and unbatched DGLGraphs.

    Writing into the frames of a batched graph must not touch the source
    graphs, and unbatching must return the values written to the batch.
    Also covers the bug reported in https://github.com/dmlc/dgl/issues/1475.
    """
    D = 10

    def zeros(rows):
        return F.zeros((rows, D))

    t1 = tree1(idtype)
    t2 = tree2(idtype)
    N1, E1 = t1.number_of_nodes(), t1.number_of_edges()
    N2, E2 = t2.number_of_nodes(), t2.number_of_edges()

    t1.ndata['h'] = F.randn((N1, D))
    t1.edata['h'] = F.randn((E1, D))
    t2.ndata['h'] = F.randn((N2, D))
    t2.edata['h'] = F.randn((E2, D))

    b1 = dgl.batch([t1, t2])
    b2 = dgl.batch([t2])

    # Zero the leading part of each batched frame.
    b1.ndata['h'][:N1] = zeros(N1)
    b1.edata['h'][:E1] = zeros(E1)
    b2.ndata['h'][:N2] = zeros(N2)
    b2.edata['h'][:E2] = zeros(E2)

    # The original graphs must be unaffected.
    assert not F.allclose(t1.ndata['h'], zeros(N1))
    assert not F.allclose(t1.edata['h'], zeros(E1))
    assert not F.allclose(t2.ndata['h'], zeros(N2))
    assert not F.allclose(t2.edata['h'], zeros(E2))

    g1, g2 = dgl.unbatch(b1)
    _g2, = dgl.unbatch(b2)

    # Unbatched graphs carry what was written to the batch.
    assert F.allclose(g1.ndata['h'], zeros(N1))
    assert F.allclose(g1.edata['h'], zeros(E1))
    assert F.allclose(g2.ndata['h'], t2.ndata['h'])
    assert F.allclose(g2.edata['h'], t2.edata['h'])
    assert F.allclose(_g2.ndata['h'], zeros(N2))
    assert F.allclose(_g2.edata['h'], zeros(E2))
def directed_tree_loss(self, all_edus, l_trees_graph, r_trees_graph, trees_graph, roots):
    """Accumulate the directed-tree loss for a batch into ``self.total_score``.

    Node embeddings are written to ``trees_graph``, one ``pull`` over all
    nodes is run with this module's message/reduce functions, and the
    temporary node fields are removed again.  Dense adjacency matrices of
    the reversed left/right tree graphs are then passed, together with the
    compatibility matrix and root scores, to ``self.logistic_loss``.  The
    loss divided by the batch size is added to ``self.total_score``, which
    is NOT reset here.

    Returns the updated ``self.total_score``.
    """
    h_cat, doc_embed, doc_lengths = self.edu_embed_model(all_edus)
    sample_node_embeds = self.split_node_embed(h_cat, doc_lengths)
    batch, seq_len, _ = h_cat.shape

    def reduce_with_doc(nodes):
        return self.reduce_func(nodes, doc_embed, batch, seq_len)

    def dense_reversed_adj(subgraph):
        # (1, n, n) dense adjacency of the reversed sub-graph.
        return (subgraph.reverse()
                .adjacency_matrix(transpose=True,
                                  ctx=th.device(self.config[DEVICE]))
                .to_dense()
                .unsqueeze(0))

    h_cat_nopadding = th.cat(sample_node_embeds)
    trees_graph.ndata['h'] = h_cat_nopadding
    trees_graph.ndata['ch_h'] = th.zeros_like(h_cat_nopadding)
    trees_graph.register_message_func(self.message_func)
    trees_graph.register_reduce_func(reduce_with_doc)
    trees_graph.pull(trees_graph.nodes())
    del trees_graph.ndata['h']
    del trees_graph.ndata['ch_h']

    left_adj, right_adj = [], []
    subgraph_pairs = zip(dgl.unbatch(l_trees_graph),
                         dgl.unbatch(r_trees_graph))
    for l_trees_subg, r_trees_subg in subgraph_pairs:
        left_adj.append(dense_reversed_adj(l_trees_subg))
        right_adj.append(dense_reversed_adj(r_trees_subg))
    left_adj = th.cat(left_adj)
    right_adj = th.cat(right_adj)

    compat_matrix = self.get_compat_matrix(h_cat)
    root_scores = self.root_clf(h_cat).view(h_cat.shape[0], -1)
    batch_loss = self.logistic_loss(compat_matrix,
                                    (left_adj, right_adj),
                                    (root_scores, roots))
    self.total_score += batch_loss / batch
    return self.total_score
def directed_tree_loss(self, all_edus, l_trees_graph, r_trees_graph, trees_graph, roots):
    """Compute the directed-tree loss for one batch.

    Dense adjacency matrices of the reversed left/right tree graphs are
    built per sample and concatenated, then passed together with the
    compatibility matrix and root scores to ``self.logistic_loss``.
    ``self.total_score`` is reset to 0 and then set to the loss divided by
    the batch size.  ``trees_graph`` is accepted but not used here.

    Returns ``self.total_score``.
    """
    h_cat, doc_embed, doc_lengths = self.edu_embed_model(all_edus)
    # Result unused in this variant; the call itself is kept.
    sample_node_embeds = self.split_node_embed(h_cat, doc_lengths)
    batch, seq_len, _ = h_cat.shape

    def dense_reversed_adj(subgraph):
        # (1, n, n) dense adjacency of the reversed sub-graph.
        reversed_graph = subgraph.reverse()
        sparse = reversed_graph.adjacency_matrix(
            transpose=True, ctx=th.device(self.config[DEVICE]))
        return sparse.to_dense().unsqueeze(0)

    left_parts, right_parts = [], []
    for l_trees_subg, r_trees_subg in zip(dgl.unbatch(l_trees_graph),
                                          dgl.unbatch(r_trees_graph)):
        left_parts.append(dense_reversed_adj(l_trees_subg))
        right_parts.append(dense_reversed_adj(r_trees_subg))

    self.total_score = 0
    left_adj = th.cat(left_parts)
    right_adj = th.cat(right_parts)

    compat_matrix = self.get_compat_matrix(h_cat)
    root_scores = self.root_clf(h_cat).view(h_cat.shape[0], -1)
    self.total_score += self.logistic_loss(
        compat_matrix, (left_adj, right_adj), (root_scores, roots)) / batch
    return self.total_score
def forward(self, bg, bg_out_hr, feature_name='energy', out_name='neu_energy'):
    """Broadcast a per-node feature of ``bg`` onto the nodes of ``bg_out_hr``.

    Both batches are unbatched and walked in lockstep.  For each pair the
    feature ``feature_name`` of the source graph is flattened to 1-D and
    repeated per node according to the source graph's ``'broadcast'`` node
    field; the result is stored (with a trailing unit dimension) as
    ``out_name`` on the paired output graph.

    Returns a new batch built from the updated output graphs.
    """
    output_gr = []
    with bg.local_scope():
        with bg_out_hr.local_scope():
            sources = dgl.unbatch(bg)
            targets = dgl.unbatch(bg_out_hr)
            for ig, source in enumerate(sources):
                target = targets[ig]

                values = source.ndata[feature_name]
                values = torch.reshape(values, (values.shape[0],))
                repeats = source.ndata['broadcast']

                spread = torch.repeat_interleave(values, repeats, dim=0)
                target.ndata[out_name] = spread[:, None]
                output_gr.append(target)
    return dgl.batch(output_gr)
def forward(self, g1, g2, mode='pairs'):
    '''Compute soft-Hausdorff distances between graphs of ``g1`` and ``g2``.

    mode:
        'pairs'     expects paired graphs; graph i of g1 is compared with
                    graph i of g2.
        'retrieval' the first graph of g1 is compared against every graph
                    in g2.

    The per-graph ``'std'`` entry of each batch's ``gdata`` is copied onto
    the unbatched graphs (for ``g1`` only when it holds more than one
    graph).  Returns the stacked distances, one per graph in ``g2``.
    Raises ``NameError`` for an unknown ``mode``.
    '''
    g1_list = dgl.unbatch(g1)
    if len(g1_list) > 1:
        for idx, graph in enumerate(g1_list):
            graph.gdata = {}
            graph.gdata['std'] = g1.gdata['std'][idx]

    g2_list = dgl.unbatch(g2)
    for idx, graph in enumerate(g2_list):
        graph.gdata = {}
        graph.gdata['std'] = g2.gdata['std'][idx]

    distances = []
    for idx, target in enumerate(g2_list):
        if mode == 'pairs':
            source = g1_list[idx]
        elif mode == 'retrieval':
            source = g1_list[0]
        else:
            raise NameError(mode + ' not implemented!')
        distances.append(self.soft_hausdorff(source, target))

    return torch.stack(distances)
def forward(self, graphs, need_weights=False):
    """Encode a batch of graphs and run the node states through ``self.blocks``.

    The per-graph encodings are concatenated and combined with the
    embedding of the ``'atomic'`` node field (concatenated or summed,
    depending on ``self.concatenate_encoding``).  Node states are padded
    into one tensor, passed through every block, then trimmed back to each
    graph's real node count and concatenated.

    Returns the node states, plus the list of per-block attention outputs
    when ``need_weights`` is true.
    """
    per_graph_codes = [
        self.encoder.encode(graph) for graph in dgl.unbatch(graphs)
    ]
    encoding = torch.cat(per_graph_codes, dim=0)
    embedding = self.node_embedder(graphs.ndata['atomic'].type(torch.long))

    if self.concatenate_encoding:
        node_state = torch.cat((encoding, embedding.squeeze()), dim=-1)
    else:
        node_state = encoding + embedding.squeeze()
    graphs.ndata['h'] = node_state

    h = torch.nn.utils.rnn.pad_sequence(
        [g.ndata['h'] for g in dgl.unbatch(graphs)])

    attentions_ = []
    for block in self.blocks:
        h, att_ = block(h)
        if need_weights:
            attentions_.append(att_)

    # Drop the padding rows again before concatenating.
    pieces = []
    for i, num_nodes in enumerate(graphs.batch_num_nodes()):
        pieces.append(h[:num_nodes, i, :])
    h = torch.cat(pieces, dim=0)

    return (h, attentions_) if need_weights else h
def forward(self, batch):
    """Compute per-node logits for a batch of trees.

    Parameters
    ----------
    batch : dgl.data.SSTBatch
        The data batch; its ``graph``, ``wordid`` and ``mask`` fields are
        read here.

    Returns
    -------
    logits : Tensor
        The prediction of each node.
    """
    # ---------- helper functions ----------
    def init_state(tree):
        # Every node starts with the mean of the tree's 'e' field.
        mean_embed = tree.ndata['e'].mean(dim=0)
        tree.ndata['s'] = mean_embed.repeat(tree.number_of_nodes(), 1)
        return tree

    def update_state(tree, state):
        assert state.dim() == 1
        tree.ndata['s'] = state.repeat(tree.number_of_nodes(), 1)
        return tree

    def extract_state(batch_tree):
        # [dmodel] --> [[dmodel]] --> [tree, dmodel] --> [tree, 1, dmodel]
        rows = []
        for tree in dgl.unbatch(batch_tree):
            rows.append(tree.ndata.pop('s')[0].unsqueeze(0))
        return th.cat(rows, dim=0).unsqueeze(1)

    def extract_hidden(batch_tree):
        # [nodes, dmodel] --> zero-pad to [max_nodes, dmodel]
        # --> [tree, max_nodes, dmodel]
        hidden = [tree.ndata.pop('h') for tree in dgl.unbatch(batch_tree)]
        max_nodes = max(h.size(0) for h in hidden)
        padded = []
        for h in hidden:
            filler = th.zeros([max_nodes - h.size(0), h.size(1)]).to(self.device)
            padded.append(th.cat([h, filler], dim=0).unsqueeze(0))
        return th.cat(padded, dim=0)
    # --------------------------------------

    g = batch.graph
    # feed embedding
    embeds = self.embedding(batch.wordid * batch.mask)
    g.ndata['c'] = th.zeros((g.number_of_nodes(), 2, self.dmodel)).to(self.device)
    g.ndata['e'] = embeds * batch.mask.float().unsqueeze(-1)
    g.ndata['h'] = embeds * batch.mask.float().unsqueeze(-1)
    g = dgl.batch([init_state(tree) for tree in dgl.unbatch(g)])

    # propagate
    for _ in range(self.T_step):
        g.register_message_func(self.cell.message_func)
        g.register_reduce_func(self.cell.reduce_func)
        g.register_apply_node_func(self.cell.apply_node_func)
        dgl.prop_nodes_topo(g)
        States = self.cell.updateGlobalVec(extract_state(g), extract_hidden(g))
        refreshed = [
            update_state(tree, state)
            for tree, state in zip(dgl.unbatch(g), States)
        ]
        g = dgl.batch(refreshed)

    # compute logits
    h = self.dropout(g.ndata.pop('h'))
    return self.linear(h)
def forward(self, bg, bg_u):
    """Scatter pooled node features back onto the un-pooled graphs.

    ``bg`` and ``bg_u`` are unbatched and walked in lockstep.  For each pair
    a zero matrix with ``parent_node[0]`` rows and ``self.dim`` columns is
    created on the module-level device ``dev``; the pooled graph's
    ``'energy'`` rows are scattered into it at the row positions given by
    the pooled graph's ``'_ID'`` field.  A fresh graph is then built with
    the edges of the un-pooled graph and the fields ``'energy'``,
    ``'parent_node'`` and ``'_ID'``.

    Returns a batch of the newly built graphs.
    """
    output_gr = []
    with bg.local_scope():
        with bg_u.local_scope():
            graph_list = dgl.unbatch(bg)
            graph_list_u = dgl.unbatch(bg_u)
            for ig, g in enumerate(graph_list):
                g_u = graph_list_u[ig]

                n_unpooled_node = g.ndata['parent_node'][0]
                selected_nodes = g.ndata['_ID'][:, None]
                pooled_node_features = g.ndata['energy']

                # Repeat each row index across the feature columns so that
                # scatter_ writes whole rows.
                expanded_node = selected_nodes.expand_as(pooled_node_features)
                expanded_node = expanded_node.to(dev)
                pooled_node_features = pooled_node_features.to(dev)

                x = torch.zeros(n_unpooled_node, self.dim, device=dev)
                x.scatter_(0, expanded_node, pooled_node_features)

                g_new = dgl.DGLGraph()
                g_new.add_nodes(n_unpooled_node)
                src, dst = g_u.edges()
                g_new.add_edges(src, dst)

                g_new.ndata['energy'] = x
                g_new.ndata['parent_node'] = g_u.ndata['parent_node'][0] * torch.ones(
                    [g_new.number_of_nodes()], dtype=torch.int)
                # clone().detach() is the documented way to copy an existing
                # tensor; torch.tensor(tensor) emits a UserWarning.
                g_new.ndata['_ID'] = (
                    g_u.ndata['_ID'].clone().detach().to(torch.int64))

                output_gr.append(g_new)
    return dgl.batch(output_gr)
def forward(self, fact_batch_graph, img_batch_graph, sem_batch_graph):
    """Apply ``self.gcn`` to each (fact, image, semantic) graph triple.

    The three batches are unbatched; triple ``i`` is made of the i-th graph
    of each batch, with the number of triples taken from the fact batch.

    Returns a batch of the graphs returned by ``self.gcn``.
    """
    fact_graphs = dgl.unbatch(fact_batch_graph)
    img_graphs = dgl.unbatch(img_batch_graph)
    sem_graphs = dgl.unbatch(sem_batch_graph)

    new_fact_graphs = [
        self.gcn(fact_graph, img_graphs[i], sem_graphs[i])
        for i, fact_graph in enumerate(fact_graphs)
    ]
    return dgl.batch(new_fact_graphs)
def test_batch_propagate():
    """Propagate along edges of a two-tree batch in two steps and check sums.

    Node ids of the second tree are shifted by 5 inside the batch.
    """
    def message(edges):
        return {'m': edges.src['h']}

    def reduce(nodes):
        return {'h': F.sum(nodes.mailbox['m'], 1)}

    t1 = tree1()
    t2 = tree2()
    bg = dgl.batch([t1, t2])
    bg.register_message_func(message)
    bg.register_reduce_func(reduce)

    # Each entry is one propagation step: (source nodes, destination nodes).
    order = [
        # step 1
        ([3, 4, 2 + 5, 0 + 5], [1, 1, 4 + 5, 4 + 5]),
        # step 2
        ([1, 2, 4 + 5, 3 + 5], [0, 0, 1 + 5, 1 + 5]),
    ]
    bg.prop_edges(order)

    t1, t2 = dgl.unbatch(bg)
    assert F.asnumpy(t1.ndata['h'][0]) == 9
    assert F.asnumpy(t2.ndata['h'][1]) == 5
def test_batch_propagate(idtype):
    """Propagate along edges of a two-tree batch in two steps and check sums.

    The message/reduce functions are passed directly to ``prop_edges``.
    Node ids of the second tree are shifted by 5 inside the batch.
    """
    def _mfunc(edges):
        return {'m': edges.src['h']}

    def _rfunc(nodes):
        return {'h': F.sum(nodes.mailbox['m'], 1)}

    t1 = tree1(idtype)
    t2 = tree2(idtype)
    bg = dgl.batch([t1, t2])

    # Each entry is one propagation step: (source nodes, destination nodes).
    step_one = ([3, 4, 2 + 5, 0 + 5], [1, 1, 4 + 5, 4 + 5])
    step_two = ([1, 2, 4 + 5, 3 + 5], [0, 0, 1 + 5, 1 + 5])
    order = [step_one, step_two]
    bg.prop_edges(order, _mfunc, _rfunc)

    t1, t2 = dgl.unbatch(bg)
    assert F.asnumpy(t1.ndata['h'][0]) == 9
    assert F.asnumpy(t2.ndata['h'][1]) == 5
def run_episode(self):
    """Run the neighbour-majority baseline on every graph of the environment.

    The environment is reset.  For each graph the initial action is the
    signals thresholded at 0.5; for ``self.T - 1`` further rounds each
    node's action is the thresholded (at 0.5) adjacency-weighted average of
    the previous actions.  The match rate against ``self.env.world`` is
    recorded every round and, when ``self.normalize`` is set, divided by a
    normal-CDF benchmark.

    Returns the mean over graphs of the per-round accuracy and its standard
    error (std divided by the square root of the number of graphs).
    """
    _, _, _ = self.env.reset()

    per_graph = []
    for graph in dgl.unbatch(self.env.graph):
        adjacency = nx.to_numpy_array(graph.to_networkx())
        signals = np.transpose(self.env.signals)

        action = np.zeros(signals.shape)
        action[signals > 0.5] = 1
        accuracy = [(action.T == self.env.world).mean()]

        row_sums = None
        for _ in range(self.T - 1):
            previous = deepcopy(action)
            row_sums = adjacency.sum(axis=1)[:, None]
            neighbor_average = adjacency.dot(previous) / row_sums
            action = np.zeros(previous.shape)
            action[neighbor_average > 0.5] = 1
            accuracy.append((action.T == self.env.world).mean())

        benchmark = norm.cdf(
            0.5, loc=0,
            scale=np.sqrt(self.env.var / graph.number_of_nodes()))
        per_graph.append(accuracy / benchmark if self.normalize else accuracy)

    last_accuracy = np.array(per_graph)
    mean_accuracy = np.mean(last_accuracy, axis=0)
    std_error = np.std(last_accuracy, axis=0) / np.sqrt(last_accuracy.shape[0])
    return mean_accuracy, std_error
def predict_ppo(self, non_fixed_variables, last_visited, temperature):
    """
    Given the state related to a node in the CP search, compute the PPO prediction
    :param non_fixed_variables: variables that are not yet fixed (i.e., must_visit)
    :param last_visited: the last city visited
    :param temperature: the softmax temperature for favoring the exploration
    :return: a flat numpy vector of probabilities of selecting an action
    """
    self.update_graph_state(non_fixed_variables, last_visited)

    model_output = self.model(self.input_graph, graph_pooling=False)
    first_graph = dgl.unbatch(model_output)[0]
    scores = first_graph.ndata["n_feat"].squeeze(-1)

    # 1 for cities that can still be selected, 0 otherwise.
    available_tensor = torch.zeros([self.n_city])
    available_tensor[non_fixed_variables] = 1

    # Shift the scores: first by |min|, then so that the largest masked
    # score becomes 0, before the masked softmax.
    scores = scores + torch.abs(torch.min(scores))
    scores = scores - torch.max(scores * available_tensor)

    probabilities = ActorCritic.masked_softmax(scores,
                                               available_tensor,
                                               dim=0,
                                               temperature=temperature)
    return probabilities.data.cpu().numpy().flatten()
def train(self, x, y):
    """
    Run one optimisation step on the smooth-L1 loss between f(x) and y
    :param x: the input; a sequence of pairs whose first item is a graph
    :param y: the true value of y
    :return: the loss value as a Python float
    """
    self.model.train()

    graph, _ = list(zip(*x))
    model_output = self.model(dgl.batch(graph), graph_pooling=False)
    node_feats = [g.ndata["n_feat"] for g in dgl.unbatch(model_output)]
    y_pred = torch.stack(node_feats).squeeze(dim=2)

    y_tensor = torch.FloatTensor(np.array(y))
    if self.args.mode == 'gpu':
        y_tensor = y_tensor.contiguous().cuda()

    loss = F.smooth_l1_loss(y_pred, y_tensor)

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    return loss.item()
def forward(self, data, efeat):
    """Run the configured GNN and return the features of the selected nodes.

    ``self.gnn_type`` selects how ``self.gnn_object`` is called:

    * ``"rgcn"`` -- ``data`` is node features, ``efeat`` is edge types;
    * ``"mpnn"`` -- ``data`` is node features, ``efeat`` is edge features;
    * ``"gat"``  -- ``data`` is node features, ``efeat`` is ignored.

    The rows kept are ``self.robot_node_indexes`` when that is non-empty;
    otherwise one row per graph of ``self.g`` is picked at offset
    ``self.grid_nodes`` from that graph's first node.  The selected rows are
    reshaped so that each output row has
    ``(1 + len(self.central_grid_nodes)) * feature_width`` entries.

    Raises ``ValueError`` when ``self.gnn_type`` is not one of the above
    (previously this surfaced as an ``UnboundLocalError`` further down).
    """
    if self.gnn_type in ("rgcn", "mpnn"):
        x = self.gnn_object(data, efeat)
    elif self.gnn_type == "gat":
        x = self.gnn_object(data)
    else:
        raise ValueError(
            "unsupported gnn_type: {!r}".format(self.gnn_type))

    if not self.robot_node_indexes:
        indexes = []
        n_nodes = 0
        for g in dgl.unbatch(self.g):
            indexes.append(n_nodes + self.grid_nodes)
            n_nodes += g.number_of_nodes()
    else:
        indexes = self.robot_node_indexes

    logits = torch.squeeze(x, 1).to(device=data.device)
    output = logits[indexes].to(device=data.device)

    outputS = output.shape
    nfeats = (1 + len(self.central_grid_nodes)) * outputS[1]
    newShape = [(outputS[0] * outputS[1]) // nfeats, nfeats]
    return output.view(newShape)
def forward(self, cand_batch, mol_tree_batch):
    """Encode candidate graphs and return one mean-pooled vector per graph.

    The candidates are converted with ``mol2dgl``, batched, and processed by
    ``self.run`` together with a non-backtracking line graph.  The result is
    unbatched and the ``'h'`` node representation of each graph is averaged
    over its nodes.  Running sample/node/edge/pass counters on ``self`` are
    updated.

    Returns the stacked per-graph representations.
    """
    (graph_list,
     tree_mess_src_edges,
     tree_mess_tgt_edges,
     tree_mess_tgt_nodes) = mol2dgl(cand_batch, mol_tree_batch)

    n_samples = len(graph_list)

    batched = batch(graph_list)
    cand_line_graph = line_graph(batched, no_backtracking=True)

    n_nodes = len(batched.nodes)
    n_edges = len(batched.edges)

    batched = self.run(batched, cand_line_graph, tree_mess_src_edges,
                       tree_mess_tgt_edges, tree_mess_tgt_nodes,
                       mol_tree_batch)

    per_graph_means = []
    for g in unbatch(batched):
        per_graph_means.append(g.get_n_repr()['h'].mean(0))
    g_repr = torch.stack(per_graph_means, 0)

    # Bookkeeping counters.
    self.n_samples_total += n_samples
    self.n_nodes_total += n_nodes
    self.n_edges_total += n_edges
    self.n_passes += 1

    return g_repr
def extractS(batchTree):
    """Collect the ``'s'`` state of each tree in ``batchTree``.

    The ``'s'`` field is popped from every unbatched tree and its first row
    is kept.  Shape flow:
    [dmodel] --> [[dmodel]] --> [tree, dmodel] --> [tree, 1, dmodel]
    """
    first_rows = []
    for tree in dgl.unbatch(batchTree):
        state = tree.ndata.pop('s')
        first_rows.append(state[0].unsqueeze(0))
    stacked = th.cat(first_rows, dim=0)
    return stacked.unsqueeze(1)
def forward(self, g, features, etypes):
    """Compute dueling Q-values from the first node of every graph in ``g``.

    ``g`` is stored on ``self.g``, its ``'norm'`` edge field is moved to the
    device of ``features``, and all GNN layers are applied.  The hidden row
    of the first node of each unbatched graph feeds the value and advantage
    heads.

    Returns ``value + advantage - mean(advantage)`` (mean over dim 1).
    """
    self.g = g
    self.g.edata['norm'] = self.g.edata['norm'].to(device=features.device)

    h = features
    for layer in self.gnn_layers:
        h = layer(self.g, h, etypes)

    unbatched = dgl.unbatch(self.g)
    gnn_output = torch.Tensor(size=(len(unbatched), self.gnn_output)).to(
        device=features.device)

    first_node = 0  # index in ``h`` of the current graph's first node
    for row, subgraph in enumerate(unbatched):
        # Output is just the room's node
        gnn_output[row, :] = h[first_node, :]
        first_node += subgraph.number_of_nodes()

    value = self.value_layers(gnn_output)
    advantage = self.advantage_layers(gnn_output)
    advAverage = torch.mean(advantage, dim=1, keepdim=True)
    return value + advantage - advAverage
def decode(self, *, sample: Dict[str, Any], prefix: str = "", metadata: Optional[Dict[str, Any]] = None) -> None:
    """Print ranked predictions for every graph in ``sample``.

    The batched graph is unbatched to find, via ``bisect_right`` on the
    sample's index array, the slice of rows belonging to each graph.  For
    each (slice, path) pair the offsets are sorted by descending value of
    column 1 of ``sample["forward"]``, incremented by one and printed after
    ``prefix`` and ``path``.  When ``metadata`` contains a ``"metadata"``
    key, a mapping from each offset to ``2**value`` (formatted with eight
    decimals) is stored there under ``path``.
    """
    graphs = unbatch(sample["typed_dgl_graph"].graph)
    numpy_indexes = sample["indexes"].indexes.cpu().numpy()

    # Row range [lo, hi) of ``numpy_indexes`` covered by each graph.
    bounds = []
    lo = 0
    nodes_seen = 0
    for graph in graphs:
        nodes_seen += graph.number_of_nodes()
        hi = bisect_right(numpy_indexes, nodes_seen - 1)
        bounds.append((lo, hi))
        lo = hi

    for (first, last), path in zip(bounds, sample["metadata"]):
        path_probas = sample["forward"][first:last, 1]
        path_indexes = sample["indexes"].offsets[first:last]
        ranking = path_probas.argsort(descending=True)
        predictions = path_indexes[ranking]

        if metadata is not None and "metadata" in metadata:
            per_index = {}
            for index, proba in zip(path_indexes.tolist(),
                                    path_probas.tolist()):
                per_index[index] = ["%.8f" % (2**proba)]
            metadata["metadata"][path] = per_index

        predictions += 1
        rendered = " ".join(map(str, predictions.numpy()))
        print("%s%s %s" % (prefix, path, rendered))
def translate_gt_graph_to_adj(gt_graph):
    """Turn each graph of a batch into per-edge-value adjacency matrices.

    For every unbatched graph and every distinct value found in its
    ``'feat'`` edge field, an ``(n_node, n_node)`` matrix is built whose
    ``[src, dst]`` entries are 1 for edges carrying that value and 0 for
    the other edges.

    Returns a list with one ``(matrices, values)`` tuple per graph, where
    ``matrices[i]`` corresponds to ``values[i]``.
    """
    gt_adjs = []
    for gt_g in dgl.unbatch(gt_graph):
        n_node = gt_g.number_of_nodes()

        srt, dst = gt_g.edges()
        srt = srt.detach().cpu().numpy()
        dst = dst.detach().cpu().numpy()
        edge_factor = gt_g.edata['feat'].detach().cpu().numpy()
        assert srt.shape[0] == edge_factor.shape[0]

        matrices = []
        values = []
        for edge_id in set(edge_factor):
            # operate in the matrix form: indicator of "edge has this value"
            indicator = np.zeros_like(edge_factor)
            indicator[np.where(edge_factor == edge_id)[0]] = 1.0

            adjacency = np.zeros((n_node, n_node))
            adjacency[srt, dst] = indicator

            matrices.append(adjacency)
            values.append(edge_id)

        gt_adjs.append((matrices, values))
    return gt_adjs
def evaluate(self, state_for_action, state_for_value, action, available_tensor):
    """
    Evaluating an action wrt. the current policy
    :param state_for_action: State used to compute the actor output
    :param state_for_value: State used to compute the critic output. Although it is the same as the
                            state_for_action, it is not the same object
    :param action: the action that is evaluated
    :param available_tensor: The actions that are possible.
    :return: the log-probabilities of the action, the critic evaluation of the state, the entropy value
    """
    if self.args.mode == "gpu":
        available_tensor = available_tensor.cuda()

    actor_output = self.action_layer(state_for_action, graph_pooling=False)
    per_graph_feats = [g.ndata["n_feat"] for g in dgl.unbatch(actor_output)]
    scores = torch.stack(per_graph_feats).squeeze(-1)

    # Row-wise shift: first by |min|, then so the largest masked score is 0.
    row_min = torch.min(scores, 1, keepdim=True)[0]
    scores = scores + torch.abs(row_min)
    row_max_available = torch.max(scores * available_tensor, 1, keepdim=True)[0]
    scores = scores - row_max_available

    action_probs = self.masked_softmax(scores, available_tensor, dim=1)

    dist = Categorical(action_probs)
    action_log_probs = dist.log_prob(action)
    dist_entropy = dist.entropy()

    state_value = self.value_layer(state_for_value, graph_pooling=True)

    return action_log_probs, torch.squeeze(state_value), dist_entropy
def forward(self, graph):
    """Embed the graph's token fields, encode it, and pool the first nodes.

    The ``'input_ids'``, ``'position_ids'`` and ``'segment_ids'`` node
    fields are embedded and then removed from the graph.  The embeddings,
    flattened to two dimensions, become the ``'h'`` node field that
    ``self.encoder`` consumes.  The ``'h'`` row of the first node of each
    unbatched graph is stacked and passed through ``self.pooler``.

    Returns ``(encoded_graph, pooled_output)``.
    """
    node_data = graph.ndata
    embedding_output = self.embeddings(node_data['input_ids'],
                                       node_data['position_ids'],
                                       node_data['segment_ids'])
    for consumed in ('input_ids', 'position_ids', 'segment_ids'):
        graph.ndata.pop(consumed)

    hidden_size = embedding_output.size(-1)
    graph.ndata['h'] = embedding_output.view(-1, hidden_size)

    graph = self.encoder(graph)

    first_node_states = [g.ndata['h'][0] for g in dgl.unbatch(graph)]
    pooled_output = self.pooler(torch.stack(first_node_states, 0))
    return graph, pooled_output
def predict(net, loader, dataset, batch_size=50, naf_obj=None, progbar=None):
    """Label every sentence produced by ``loader`` and optionally write NAF.

    ``net`` is put in eval mode and ``net.label`` is called per batch under
    ``torch.no_grad()``.  For each graph of the batch, the sentence obtained
    from ``dataset.conllu`` has its tokens annotated with ``ROLE``,
    ``pROLE``, ``FRAME`` and ``pFRAME``, taken from the batch-level outputs
    at the token position shifted by the running node offset.  Frames are
    then built with ``make_frames`` and, when ``naf_obj`` is given, written
    with ``write_frames_to_naf``.  ``progbar`` (if given) is advanced by
    ``batch_size`` per batch and finished at the end.

    Returns ``None``.
    """
    net.eval()
    with torch.no_grad():
        for gs in loader:
            frame_labels, role_labels, frame_chance, role_chance = net.label(gs)

            node_offset = 0
            for g in dgl.unbatch(gs):
                sentence = dataset.conllu(g)
                for pos, token in enumerate(sentence, start=node_offset):
                    token.ROLE = role_labels[pos]
                    token.pROLE = role_chance[pos]
                    token.FRAME = frame_labels[pos]
                    token.pFRAME = frame_chance[pos]
                node_offset += len(g)

                # match the predicate and roles by some simple graph
                # traversal rules
                frames, _orphans = make_frames(sentence)

                if naf_obj:
                    write_frames_to_naf(naf_obj, frames, sentence)

            if progbar:
                progbar.next(batch_size)

    if progbar:
        progbar.finish()
def getMaskForBatch(subgraph):
    """Return the index of the first node of every graph in the batch.

    Index ``i`` of the result is the sum of the node counts of the graphs
    that precede graph ``i`` in ``dgl.unbatch(subgraph)``.
    """
    node_counts = [g.number_of_nodes() for g in dgl.unbatch(subgraph)]

    first_node_indexes = []
    running_total = 0
    for count in node_counts:
        first_node_indexes.append(running_total)
        running_total += count
    return first_node_indexes
def act(self, graph_state, available_tensor):
    """
    Perform an action following the probabilities outputted by the current actor
    :param graph_state: the current state
    :param available_tensor: [0,1]-vector of available actions
    :return: the action selection, its log-probability, and the action probabilities
    """
    if self.args.mode == "gpu":
        available_tensor = available_tensor.cuda()

    single_batch = dgl.batch([graph_state, ])

    self.action_layer.eval()
    with torch.no_grad():
        actor_output = self.action_layer(single_batch, graph_pooling=False)
        first_graph = dgl.unbatch(actor_output)[0]
        scores = first_graph.ndata["n_feat"].squeeze(-1)

        # Doing post-processing on the output to have numerically stable
        # probabilities given that a mask is used
        scores = scores + torch.abs(torch.min(scores))
        scores = scores - torch.max(scores * available_tensor)
        action_probs = self.masked_softmax(scores, available_tensor, dim=0)

        dist = Categorical(action_probs)
        action = dist.sample()

    return action, dist.log_prob(action), action_probs
def clone(self):
    """Return a copy of this mini-batch.

    ``X`` is copied according to its type:

    * list of tensors  -- each tensor is detached and cloned;
    * ``torch.Tensor`` -- detached and cloned;
    * ``dgl.DGLGraph`` -- unbatched and re-batched, with the masks read back
      from the new graph's node data.

    ``Y`` is detached and cloned, ``lengths`` and ``ids`` are shallow
    copied, ``data`` is deep copied.  The field ``P`` is not supported and
    is passed as ``None``.

    Raises ``AssertionError`` when ``P`` is set or ``X`` has an unhandled
    type.
    """
    # Explicit raises (same exception type as before) so that the checks
    # survive ``python -O``.
    if self.P is not None:
        raise AssertionError("clone not implemented for field P.")

    if isinstance(self.X, list):
        X = [x.detach().clone() for x in self.X]
        M = {
            key: value.detach().clone()
            for key, value in self.masks.items()
        }
    elif isinstance(self.X, torch.Tensor):
        X = self.X.detach().clone()
        # This branch used to leave ``M`` unassigned, so the MiniBatch(...)
        # call below raised UnboundLocalError.  The masks are copied as in
        # the list branch -- assumes they are tensors here too; TODO confirm.
        M = {
            key: value.detach().clone()
            for key, value in self.masks.items()
        }
    elif isinstance(self.X, dgl.DGLGraph):
        X = dgl.batch(dgl.unbatch(self.X))
        M = {mask: X.ndata[mask] for mask in self.masks}
    else:
        raise AssertionError(
            "unhandled type to clone: {}".format(type(self.X)))

    return MiniBatch(
        X,
        self.Y.detach().clone(),  # tensor
        copy.copy(self.lengths),  # list
        M,  # dict of tensors
        None,  # P is not cloned (guarded above)
        copy.deepcopy(self.data) if self.data is not None else None,
        copy.copy(self.ids) if self.ids is not None else None,  # list
    )
def getMaskForBatch(subgraph):
    """Return the index of the first node of every graph in the batch.

    Index ``i`` of the result is the sum of the node counts of the graphs
    that precede graph ``i`` in ``dgl.unbatch(subgraph)``.
    """
    # ``boundaries`` holds the running totals; the final entry (total node
    # count) is not a start index and is dropped on return.
    boundaries = [0]
    for g in dgl.unbatch(subgraph):
        boundaries.append(boundaries[-1] + g.number_of_nodes())
    return boundaries[:-1]
def test_all(self, dataset: AllDataset, output_dir: str = "test_result"):
    """Run the model over ``dataset.test`` and dump one CSV per data frame.

    ``output_dir`` is created when missing.  The model is loaded and put in
    eval mode; each batch is passed through ``self.forward`` and every
    (graph, info) pair is converted by ``graph_and_info_to_df_list`` into
    data frames that are written as ``1.csv``, ``2.csv``, ... in
    ``output_dir``.  The model is switched back to train mode afterwards
    and the elapsed time is printed.
    """
    target_dir = os.path.abspath(output_dir)
    if os.path.exists(output_dir):
        print(f'output dir {os.path.abspath(output_dir)} exists !')
    else:
        os.makedirs(output_dir)
        print(
            f"make new dir {os.path.abspath(output_dir)}, and write files into it."
        )
    del target_dir  # only the f-strings above report the path

    self.load()
    self.eval()

    data_loader = GraphDataLoader(dataset.test,
                                  collate_fn=collate,
                                  batch_size=10,
                                  shuffle=False,
                                  drop_last=False)

    start_time = time.time()
    file_name_index = 1
    for bhg, info in data_loader:
        self.forward(bhg)
        for cg, cd in zip(dgl.unbatch(bhg), info):
            for df in graph_and_info_to_df_list(cg, cd):
                csv_path = os.path.join(output_dir,
                                        str(file_name_index) + ".csv")
                df.to_csv(csv_path, index=False)
                file_name_index += 1

    self.train()
    print(
        f"test time is :{time.time() - start_time:6.2f} s | num_samples : {len(dataset.test)}"
    )