def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                      out: Optional[torch.Tensor] = None,
                      dim_size: Optional[int] = None,
                      eps: float = 1e-12) -> torch.Tensor:
    """Compute ``log(sum(exp(src)))`` over the groups given by ``index``.

    Values of ``src`` that share an ``index`` entry along ``dim`` are reduced
    together. The per-group maximum is subtracted before exponentiating so the
    exponentials cannot overflow.

    Args:
        src: Source tensor; must have a floating-point dtype.
        index: Group indices, broadcast against ``src`` along ``dim``.
        dim: Dimension along which to reduce.
        out: Optional output tensor. Its size along ``dim`` sets the number of
            groups, its values are included in the sum, and it is modified in
            place.
        dim_size: Number of groups when ``out`` is not given; defaults to
            ``index.max() + 1``.
        eps: Added to the summed exponentials before taking the log.

    Returns:
        Tensor shaped like ``src`` except that ``dim`` has ``dim_size`` entries.

    Raises:
        ValueError: if ``src`` is not floating point.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    if out is not None:
        dim_size = out.size(dim)
    elif dim_size is None:
        dim_size = int(index.max()) + 1

    # torch.Size is an immutable tuple; copy it to a list before editing.
    size = list(src.size())
    size[dim] = dim_size
    max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
                                     device=src.device)
    # Fills max_value_per_index with the per-group maxima (in place).
    scatter_max(src, index, dim, max_value_per_index, dim_size)
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_scores = src - max_per_src_element
    # In a group whose maximum is -inf every element is -inf, and
    # (-inf) - (-inf) is NaN; make those entries contribute exp(-inf) = 0.
    recentered_scores.masked_fill_(max_per_src_element == float('-inf'),
                                   float('-inf'))

    if out is not None:
        # ``out`` has one entry per group, so recentre it with the per-group
        # maxima (the per-element tensor has the shape of ``src``).
        out = out.sub_(max_value_per_index).exp_()

    sum_per_index = scatter_sum(recentered_scores.exp_(), index, dim, out,
                                dim_size)
    return sum_per_index.add_(eps).log_().add_(max_value_per_index)
def message_step(self, node_network, edge_network, x, start, end, e):
    """Run one residual message-passing step over nodes and edges.

    Args:
        node_network: Module applied to ``[x, aggregated edge messages]``.
        edge_network: Module applied to ``[x[start], x[end], e]``.
        x: Node features; rows are indexed by ``start`` / ``end``.
        start, end: Source and target node index of every edge.
        e: Edge features, one row per edge.

    Returns:
        ``(x_out, e_out)``: updated node and edge features, each with a
        residual connection to its input.

    Edge features are aggregated onto their ``end`` node according to
    ``self.hparams["aggregation"]``: ``"max"``, ``"sum_max"`` (max and sum
    concatenated on the last dim), or a plain sum for ``"sum"`` and any
    other value.
    """
    num_nodes = x.shape[0]
    aggregation = self.hparams["aggregation"]

    # Compute new node features.
    if aggregation == "max":
        edge_messages = scatter_max(e, end, dim=0, dim_size=num_nodes)[0]
    elif aggregation == "sum_max":
        edge_messages = torch.cat(
            [
                scatter_max(e, end, dim=0, dim_size=num_nodes)[0],
                scatter_add(e, end, dim=0, dim_size=num_nodes),
            ],
            dim=-1,
        )
    else:
        # "sum", which is also the fallback for unrecognised values.
        edge_messages = scatter_add(e, end, dim=0, dim_size=num_nodes)

    node_inputs = torch.cat([x, edge_messages], dim=-1)
    # Out-of-place residual: an in-place ``+=`` on a module output breaks
    # autograd when the module's last op saved its output for backward.
    x_out = node_network(node_inputs) + x

    # Compute new edge features from the node features *before* this update.
    edge_inputs = torch.cat([x[start], x[end], e], dim=-1)
    e_out = edge_network(edge_inputs) + e
    return x_out, e_out
def message_step(self, x, start, end):
    """Run one gated message-passing step and return updated node features.

    Each edge gets a gate ``sigmoid(edge_network([x[start], x[end]]))``. The
    gated source features ``gate * x[start]`` are aggregated onto the target
    nodes according to ``self.hparams["aggregation"]`` (``"sum"``, ``"max"``
    or ``"sum_max"``). The new features are ``node_network([messages, x])``
    plus a residual connection to ``x``.

    Raises:
        ValueError: if the configured aggregation is not recognised.
    """
    edge_inputs = torch.cat([x[start], x[end]], dim=1)
    e = torch.sigmoid(self.edge_network(edge_inputs))

    aggregation = self.hparams["aggregation"]
    num_nodes = x.shape[0]
    # Gate the source features once and reuse them for every reduction.
    weighted = e * x[start]
    if aggregation == "sum":
        messages = scatter_add(weighted, end, dim=0, dim_size=num_nodes)
    elif aggregation == "max":
        messages = scatter_max(weighted, end, dim=0, dim_size=num_nodes)[0]
    elif aggregation == "sum_max":
        messages = torch.cat(
            [
                scatter_max(weighted, end, dim=0, dim_size=num_nodes)[0],
                scatter_add(weighted, end, dim=0, dim_size=num_nodes),
            ],
            dim=-1,
        )
    else:
        raise ValueError("Unknown aggregation: {!r}".format(aggregation))

    node_inputs = torch.cat([messages, x], dim=1)
    # Out-of-place residual keeps autograd safe if node_network's last op
    # saved its output for the backward pass.
    return self.node_network(node_inputs) + x
def get_aggregation(aggregation):
    """Return the aggregation callable selected by ``hparams["aggregation"]``.

    Every returned callable takes ``(e, end, x)`` and reduces the rows of
    ``e`` into ``x.shape[0]`` buckets along dim 0 using the indices in
    ``end``. The combined modes concatenate two reductions on the last dim:
    ``"sum_max"`` = [max, sum], ``"mean_sum"`` = [mean, sum],
    ``"mean_max"`` = [max, mean].

    Raises:
        KeyError: if ``aggregation`` is not a supported name.
    """

    def _sum(e, end, x):
        return scatter_add(e, end, dim=0, dim_size=x.shape[0])

    def _mean(e, end, x):
        return scatter_mean(e, end, dim=0, dim_size=x.shape[0])

    def _max(e, end, x):
        return scatter_max(e, end, dim=0, dim_size=x.shape[0])[0]

    def _concat(first, second):
        # Evaluate ``first`` then ``second`` and join them feature-wise.
        return lambda e, end, x: torch.cat(
            [first(e, end, x), second(e, end, x)], dim=-1
        )

    aggregation_dict = {
        "sum": _sum,
        "mean": _mean,
        "max": _max,
        "sum_max": _concat(_max, _sum),
        "mean_sum": _concat(_mean, _sum),
        "mean_max": _concat(_max, _mean),
    }
    return aggregation_dict[aggregation]
def forward(self, x, edge_index, edge_attr, u, batch):
    """Node update of a graph-network block.

    Shapes:
        x: [N, F_x]; edge_index: [2, E] with max entry N - 1;
        edge_attr: [E, F_e]; u: [B, F_u]; batch: [N] with max entry B - 1.

    Messages ``node_mlp_1([x[col], edge_attr])`` are reduced onto ``row``
    with ``self.aggregation``: ``"mean"``, ``"min"``, ``"max"`` or
    ``"minmax"`` (min and max concatenated). ``node_mlp_2`` then combines
    them with the node features and the graph features ``u[batch]``.

    Raises:
        ValueError: for an unknown ``self.aggregation``.
    """
    row, col = edge_index
    messages = self.node_mlp_1(torch.cat([x[col], edge_attr], dim=1))
    n = x.size(0)

    mode = self.aggregation
    if mode == "mean":
        aggregated = scatter_mean(messages, row, dim=0, dim_size=n)
    elif mode == "min":
        aggregated = scatter_min(messages, row, dim=0, dim_size=n)[0]
    elif mode == "max":
        aggregated = scatter_max(messages, row, dim=0, dim_size=n)[0]
    elif mode == "minmax":
        lowest = scatter_min(messages, row, dim=0, dim_size=n)[0]
        highest = scatter_max(messages, row, dim=0, dim_size=n)[0]
        aggregated = torch.cat([lowest, highest], dim=1)
    else:
        raise ValueError("Unknown aggregation type: {}".format(self.aggregation))

    return self.node_mlp_2(torch.cat([x, aggregated, u[batch]], dim=1))
def forward(self, x, x_privileged, action_masks):
    """Compute per-agent action probabilities and per-sample state values.

    Args:
        x: Batched observations; ``x.size()[0]`` is the batch size.
        x_privileged: Not read by this method.
        action_masks: Reshaped to ``(-1, self.naction)``; entries equal to 0
            mark actions whose logits are set to ``-inf``.

    Returns:
        ``(probs, values)``: softmax action probabilities padded by
        ``active_agents.pad`` and one value per batch element.
    """
    batch_size = x.size()[0]
    x, active_agents, (pitems, pmask) = self.latents(x, action_masks)

    # Allocate on x's device. ``torch.cuda.FloatTensor`` allocates on the
    # *current* CUDA device, which need not be the one x lives on.
    vin = torch.zeros(batch_size, self.d_agent * self.hps.dff_ratio, device=x.device)
    # Max-pool agent features into their batch rows of ``vin`` (in place).
    # NOTE(review): if scatter_max includes ``out``'s initial values in the
    # max, pooled features are clamped below at 0 — confirm this is intended.
    scatter_max(x, index=active_agents.batch_index, dim=0, out=vin)
    if self.hps.use_privileged:
        # Shift masked items down by 1000 so they lose the max.
        mask1k = 1000.0 * pmask.float().unsqueeze(-1)
        pitems_max = (pitems - mask1k).max(dim=1).values
        # Presumably rows whose items are all masked (and zero-valued); zero them.
        pitems_max[pitems_max == -1000.0] = 0.0
        # Sum of items over the unmasked count (clamped to >= 1); this is a
        # mean over unmasked items only if masked items are zero.
        pitems_avg = pitems.sum(dim=1) / torch.clamp_min(
            (~pmask).float().sum(dim=1), min=1).unsqueeze(-1)
        vin = torch.cat([vin, pitems_max, pitems_avg], dim=1)
    values = self.value_head(vin).view(-1)

    logits = self.policy_head(x)
    logits = logits.masked_fill(
        action_masks.reshape(-1, self.naction)[active_agents.flat_index] == 0,
        float('-inf'))
    probs = F.softmax(logits, dim=1)
    probs = active_agents.pad(probs)
    return probs, values
def forward(self, x, edge_index):
    """Encode nodes, run max-aggregation message passing, and score edges.

    For ``self.hparams["n_graph_iters"]`` iterations, every node receives
    the element-wise maximum of its incoming-neighbour and outgoing-neighbour
    maxima. ``node_network([h, messages])`` is added to ``h`` as a residual
    update. Returns ``edge_network([h[start], h[end]])``.
    """
    h = self.node_encoder(x)
    start, end = edge_index
    n = h.shape[0]

    for _ in range(self.hparams["n_graph_iters"]):
        from_sources = scatter_max(h[start], end, dim=0, dim_size=n)[0]
        from_targets = scatter_max(h[end], start, dim=0, dim_size=n)[0]
        # Maximum over the two directions via stack + max(dim=0).
        messages = torch.stack((from_sources, from_targets), dim=0).max(0)[0]
        h = self.node_network(torch.cat([h, messages], dim=-1)) + h

    return self.edge_network(torch.cat([h[start], h[end]], dim=1))
def forward(self, x, x_privileged, action_masks):
    """Compute action probabilities for every agent slot and state values.

    Args:
        x: Batched observations; ``x.size()[0]`` is the batch size.
        x_privileged: Forwarded to ``self.latents``.
        action_masks: Forwarded to ``self.latents``.

    Returns:
        ``(probs, values)``: probabilities of shape
        ``(batch_size, self.agents, self.naction)`` (slots without an agent
        keep logit 0) and one value per batch element.
    """
    batch_size = x.size()[0]
    x, indices, groups, (pitems, pmask) = self.latents(x, x_privileged, action_masks)

    # Allocate on x's device. ``torch.cuda.FloatTensor`` allocates on the
    # *current* CUDA device, which need not be the one x lives on.
    vin = torch.zeros(batch_size, self.d_agent * self.hps.dff_ratio, device=x.device)
    # Max-pool agent features into their group rows of ``vin`` (in place).
    # NOTE(review): if scatter_max includes ``out``'s initial values in the
    # max, pooled features are clamped below at 0 — confirm this is intended.
    scatter_max(x, index=groups, dim=0, out=vin)
    if self.hps.use_privileged:
        # Shift masked items down by 1000 so they lose the max.
        mask1k = 1000.0 * pmask.float().unsqueeze(-1)
        pitems_max = (pitems - mask1k).max(dim=1).values
        # Presumably rows whose items are all masked (and zero-valued); zero them.
        pitems_max[pitems_max == -1000.0] = 0.0
        # Sum of items over the unmasked count (clamped to >= 1).
        pitems_avg = pitems.sum(dim=1) / torch.clamp_min(
            (~pmask).float().sum(dim=1), min=1).unsqueeze(-1)
        vin = torch.cat([vin, pitems_max, pitems_avg], dim=1)
    values = self.value_head(vin).view(-1)

    logits = self.policy_head(x)
    # Scatter per-agent logits into a dense (batch * agents, naction) grid.
    padded_logits = torch.zeros(batch_size * self.agents, self.naction, device=x.device)
    scatter_add(logits, index=indices, dim=0, out=padded_logits)
    padded_logits = padded_logits.view(batch_size, self.agents, self.naction)
    probs = F.softmax(padded_logits, dim=2)
    return probs, values
def pathat_layer(self, input, pathM, pathlens, eluF=True):
    """Path-attention layer: aggregate path features into node features.

    Args:
        input: Node feature matrix; ``input.size()[0]`` is the node count N.
        pathM: Mapping from path length to a dict with ``'indices'`` (row 0
            holds the node each path is attributed to) and ``'values'`` (one
            row of node ids per path).
        pathlens: Path lengths to use; forced to ``[2]`` when ``self.concat``
            is false.
        eluF: Apply ELU to the output when ``self.concat`` is also true.

    Returns:
        Node features with N rows.
    """
    N = input.size()[0]
    device = input.device
    pathh = torch.mm(input, self.pathW) + self.pathbias  # N x out
    if not self.concat:  # the last layer only uses length-2 paths
        pathlens = [2]

    def _segment_softmax(scores, group):
        # Numerically stable softmax of ``scores`` within each ``group`` id.
        scores = scores - scatter_max(scores, group, dim=0, dim_size=N)[0][group]
        scores = scores.exp()
        return scores / (scatter_add(scores, group, dim=0, dim_size=N)[group] + 9e-15)

    per_length = []
    h_path_prime = None
    for pathlen_iter in pathlens:
        i = pathM[pathlen_iter]['indices']
        v = pathM[pathlen_iter]['values']
        owner = i[0, :]
        # Features of every node on each path except its first entry,
        # num_paths x (pathlen - 1) x out, max-pooled along the path.
        pathfeat, _ = torch.max(pathh[v[:, 1:], :], dim=1)

        att_feat = torch.cat((pathfeat, pathh[owner, :]), dim=1).t()
        att_weight = self.patha_2 if pathlen_iter == 2 else self.patha_3
        path_att = self.leakyrelu(att_weight.mm(att_feat).squeeze())
        path_att = _segment_softmax(path_att, owner).view(-1, 1)
        path_att = self.dropout(path_att)  # dropout on the path weights
        w_pathfeat = torch.mul(pathfeat, path_att)
        # dim_size=N keeps one row per node even when the highest-numbered
        # nodes own no path, so the stacked rows line up with ``leni`` below.
        h_path_prime = scatter_add(w_pathfeat, owner, dim=0, dim_size=N)
        per_length.append(h_path_prime)

    if len(pathlens) == 2:
        # Attention over the two path lengths for every node.
        pathfeat_all = torch.cat(per_length, dim=0)
        leni = torch.arange(N, device=device).repeat(2)
        att_feat = torch.cat((pathfeat_all, pathh[leni, :]), dim=1).t()
        path_att = self.leakyrelu(self.lenAtt.mm(att_feat).squeeze())
        path_att = _segment_softmax(path_att, leni).view(-1, 1)
        w_pathfeat = torch.mul(pathfeat_all, path_att)
        h_path_prime = scatter_add(w_pathfeat, leni, dim=0, dim_size=N)

    if self.concat and eluF:
        return F.elu(h_path_prime)
    return h_path_prime
def multi_hop(self, triple_prob, distance, head, tail, concept_label, triple_label, gamma=0.8, iteration=3, method="avg"):
    '''Propagate concept scores along triples for ``iteration`` hops.

    Shapes:
        triple_prob: bsz x L x mem_t
        distance: bsz x mem
        head, tail: bsz x mem_t
        concept_label: bsz x mem
        triple_label: bsz x mem_t

    A binary vector marks the source concepts (``distance == 0``) and is
    expanded to bsz x L x mem. Each hop scores every triple from its head
    concept and reduces the result onto its tail concept:

    - avg: s(v) = Avg_{u in N(v)} gamma * s(u) + R(u->v)
    - max: s(v) = max_{u in N(v)} gamma * s(u) + R(u->v)

    Positions labelled -1 are zeroed. Returns the sum of the per-hop scores
    with source concepts shifted by -1e5, shaped bsz x L x mem.

    Raises:
        ValueError: if ``method`` is neither "avg" nor "max".
    '''
    if method not in ("avg", "max"):
        raise ValueError("method must be 'avg' or 'max', got {!r}".format(method))

    bsz, seq_len = triple_prob.size(0), triple_prob.size(1)
    cpt_size = (bsz, seq_len, distance.size(1))
    # A real (non-expanded) tensor: in-place masked_fill_ on an ``expand``-ed
    # view raises because its rows alias the same memory.
    init_mask = torch.zeros(cpt_size, dtype=torch.float, device=distance.device)
    init_mask.masked_fill_((distance == 0).unsqueeze(1), 1)
    final_mask = init_mask.clone()
    init_mask.masked_fill_((concept_label == -1).unsqueeze(1), 0)
    concept_probs = [init_mask]

    head = head.unsqueeze(1).expand(bsz, seq_len, -1)
    tail = tail.unsqueeze(1).expand(bsz, seq_len, -1)
    ignored_triples = (triple_label == -1).unsqueeze(1)
    ignored_concepts = (concept_label == -1).unsqueeze(1)

    for _ in range(iteration):
        node_score = concept_probs[-1]
        # Score of each triple's head concept.
        triple_head_score = node_score.gather(2, head)
        triple_head_score.masked_fill_(ignored_triples, 0)

        update_value = triple_head_score * gamma + triple_prob
        out = torch.zeros_like(node_score)
        # Reduce onto the tail concepts; ``out`` is filled in place.
        if method == "max":
            scatter_max(update_value, tail, dim=-1, out=out)
        else:
            scatter_mean(update_value, tail, dim=-1, out=out)
        out.masked_fill_(ignored_concepts, 0)
        concept_probs.append(out)

    # Natural decay of concepts that are multi-hop away from the source;
    # the source concepts themselves are pushed down by -1e5.
    total_concept_prob = final_mask * -1e5
    for prob in concept_probs[1:]:
        total_concept_prob += prob
    return total_concept_prob  # bsz x L x mem
def aggregate_multi(self, ind_lst, pos_lst, row_rel, query_rel):
    u"""Attention-weighted aggregation of row relations per entity (tuple).

    Args:
        ind_lst (list) -- List of indices of entity (tuple) to which each row of row_rel belongs.
        pos_lst (list) -- List of indices of batch to which each row of row_rel targets.
        row_rel (Tensor) -- (n_rel x emb_dim)
        query_rel (Tensor) -- (n_batch x emb_dim)
    Return:
        out (Tensor) -- (n_batch x emb_dim)

    ``self.args.aggregation`` selects ``"attention"`` (dot-product scores)
    or ``"scaled-attention"`` (scores divided by sqrt(emb_dim)).

    Raises:
        ValueError: for any other aggregation mode.
    """
    mode = self.args.aggregation
    if mode not in ("attention", "scaled-attention"):
        raise ValueError("Unknown aggregation mode: {}".format(mode))

    device = row_rel.device
    ind_lst = torch.as_tensor(ind_lst, dtype=torch.long, device=device)
    pos_lst = torch.as_tensor(pos_lst, dtype=torch.long, device=device)

    row_query = query_rel[pos_lst]
    # Score between each row relation and its query relation: (n_row x 1).
    o = torch.sum(row_rel * row_query, dim=1, keepdim=True)
    if mode == "scaled-attention":
        o = o / row_rel.size(1) ** 0.5

    # Per-entity maximum score, subtracted to keep exp() from overflowing.
    # Scores are shifted to be non-negative for scatter_max (presumably to
    # sidestep its fill value of 0) and shifted back afterwards.
    min_o = torch.min(o).item()
    m, _ = scatter_max(o - min_o, ind_lst, dim=0)
    m = (m + min_o)[ind_lst]  # (n_row x 1)

    # Softmax weight of each row relation within its entity.
    a = torch.exp(o - m)
    w = a / scatter_add(a, ind_lst, dim=0)[ind_lst]  # (n_row x 1)

    # Weighted sum of row relations per entity.
    return scatter_add(w * row_rel, ind_lst, dim=0)
def run_inner_loop(self, inputs):
    """One gated message-passing step in both edge directions.

    Args:
        inputs: Tuple ``(x, start, end)`` of node features and the source
            and target node index of every edge.

    Returns:
        New node features ``node_network([messages, x])`` (no residual).

    Each edge gets a gate ``sigmoid(edge_network([x[start], x[end]]))``.
    ``gate * x[start]`` is reduced onto ``end`` and ``gate * x[end]`` onto
    ``start`` with ``self.hparams["aggregation"]``: ``"sum"``, ``"max"`` or
    ``"sum_max"`` (max then sum). All parts are concatenated on the last dim.

    Raises:
        ValueError: if the aggregation is not recognised.
    """
    x, start, end = inputs
    aggregation = self.hparams["aggregation"]
    if aggregation not in ("sum", "max", "sum_max"):
        raise ValueError("Unknown aggregation: {}".format(aggregation))
    n = x.shape[0]

    edge_inputs = torch.cat([x[start], x[end]], dim=1)
    e = torch.sigmoid(self.edge_network(edge_inputs))

    def _reduce(src, index):
        # Reduce one gated direction onto ``index``; max comes before sum.
        reduced = []
        if aggregation in ("max", "sum_max"):
            reduced.append(scatter_max(src, index, dim=0, dim_size=n)[0])
        if aggregation in ("sum", "sum_max"):
            reduced.append(scatter_add(src, index, dim=0, dim_size=n))
        return reduced

    # Each gated product is computed once per direction and released after
    # its reductions, rather than being recomputed for every reduction.
    messages = torch.cat(_reduce(e * x[start], end) + _reduce(e * x[end], start), dim=-1)

    node_inputs = torch.cat([messages, x], dim=1)
    return self.node_network(node_inputs)
def _add_to_memory(self, key, payload, batch_ids=None): num_add = key.size(0) self._key[self._num_items:self._num_items + num_add] = key for (payload, add_pl) in zip(self._payload, payload): payload[self._num_items:self._num_items + num_add] = add_pl if self.batch_size > 1: self._batch_ids[self._num_items:self._num_items + num_add] = batch_ids self._num_items += num_add if self.batch_size > 1 and (self._batch_num_items_lb < self.capacity).any(): # Add number of items per batch entry, and see if for some entries we exceed the capacity # for the first time so we can define a bound new_batch_num_items = self._batch_num_items_lb + bincount( batch_ids, minlength=self.batch_size) new_bound = (new_batch_num_items >= self.capacity) & ( self._batch_num_items_lb < self.capacity) self._batch_num_items_lb = new_batch_num_items if new_bound.any(): # Note: this is not a very tight bound, we can get a tighter bound by taking the max of the first # 'capacity' values only per instance, instead of all values (as in the without batch case) # Better is to take the k-th largest value, which is not done here (to save computation) # but when the queue is actually reduced self.bound = torch.where( new_bound, scatter_max(self._key[:self._num_items], self._batch_ids[:self._num_items].long(), dim_size=self.batch_size)[0], self.bound) elif self._num_items >= self.capacity and self.bound is None and self.batch_size == 1: # As soon as we have more than capacity items for the first time, we can define a bound # using the first 'capacity' items (gives slightly better bound than using _num_items items) self.bound = self._key[:self.capacity].max().cpu()
def gat_layer(self, input, adj, genPath=False, eluF=True):
    """Sparse graph-attention layer.

    Args:
        input: Node features; ``input.size()[0]`` is the node count N.
        adj: Sparse adjacency; ``adj._indices()`` gives the edge list, whose
            row 0 is used as the softmax group (source) of each edge.
        genPath: When truthy, a directory in which the per-source softmax of
            the attention scores is saved as ``attmat_<layerN>.npz``.
        eluF: Apply ELU to the output when ``self.concat`` is also true.

    Returns:
        Updated node features with N rows.
    """
    N = input.size()[0]
    device = input.device
    edge = adj._indices()
    h = torch.mm(input, self.W) + self.bias  # N x out

    # Shared attention mechanism on [h_src || h_dst] for every edge.
    edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()  # 2*D x E
    edge_e_a = self.leakyrelu(self.a.mm(edge_h).squeeze())  # E attention scores

    if genPath:
        with torch.no_grad():
            src = edge[0, :]
            p_a_e = edge_e_a - scatter_max(edge_e_a, src, dim=0, dim_size=N)[0][src]
            p_a_e = p_a_e.exp()
            p_a_e = p_a_e / (scatter_add(p_a_e, src, dim=0, dim_size=N)[src] + 9e-15)
            scisp = convert.to_scipy_sparse_matrix(edge, p_a_e, N)
            scipy.sparse.save_npz(os.path.join(genPath, 'attmat_{:s}.npz'.format(self.layerN)), scisp)

    # Allocate helpers on the input's device instead of hard-coding .cuda(),
    # so the layer also runs on CPU and on non-default GPUs.
    edge_e = torch.exp(edge_e_a - torch.max(edge_e_a))  # E
    e_rowsum = spmm(edge, edge_e, N, torch.ones(size=(N, 1), device=device))  # N x 1
    edge_e = self.dropout(edge_e)  # dropout on attention weights (noted to raise accuracy 82.4 -> 83.8)
    h_prime = spmm(edge, edge_e, N, h)
    h_prime = h_prime.div(e_rowsum + 9e-15)  # N x out
    if self.concat and eluF:
        return F.elu(h_prime)
    return h_prime
def forward(self, x_scores, x_relations, berts, edge_indices, softmax_edge_indices, n_program, max_y_score_len):
    """Score candidates per program and compute relation features.

    Returns:
        ``(score, relation)``: ``score`` has shape
        ``[n_program, max_y_score_len]``, with -inf in every slot that
        received no score; ``relation`` is ``self.getRelation(x_relations)``.
    """
    feature = self.getFeature(x_scores)
    if self.reduce == 'max':
        # Append one all-zero row to the feature matrix.
        padding = torch.zeros(1, feature.shape[1], device=feature.device)
        feature = torch.cat([feature, padding], dim=0)

    gathered_feature, _ = self.gatherFeature(
        feature[edge_indices[0]],
        edge_indices[1],
        dim_size=torch.max(edge_indices[1]).long() + 1,
        dim=0)
    processed_feature = torch.cat([self.processFeature(gathered_feature), berts], dim=-1)
    score = self.getScore(processed_feature)
    # 0 marks empty slots below, so no real score may be exactly 0.
    assert not (score == 0).any()

    score, _ = torch_scatter.scatter_max(score, softmax_edge_indices[1], dim_size=n_program * max_y_score_len, dim=0)  #! to update here
    score = score.masked_fill(score == 0, float('-inf'))
    score = torch.reshape(score, [n_program, max_y_score_len])

    return score, self.getRelation(x_relations)
def softmax(self, x, index, num=None):
    """Softmax of ``x`` computed separately within each group of ``index``.

    The group maximum is subtracted before exponentiating, and 1e-16 is added
    to each group's denominator. ``num`` is passed on as ``dim_size``.
    """
    group_max = torch_scatter.scatter_max(x, index, dim=0, dim_size=num)[0]
    exps = (x - group_max[index]).exp()
    denom = torch_scatter.scatter_add(exps, index, dim=0, dim_size=num)
    return exps / (denom[index] + 1e-16)
def forward(self, pt_fea, xy_ind):
    """Max-pool per-point features into their (sample, grid-cell) bins.

    Args:
        pt_fea: List of per-sample point-feature tensors.
        xy_ind: List of per-sample grid-index tensors, one row per point.

    Returns:
        ``(unq, processed_pooled_data)``: the unique ``[sample, *grid_index]``
        rows (int64) and the pooled feature of each row, passed through
        ``self.fea_compression`` when ``self.fea_compre`` is set.
    """
    # ``.device`` works for CPU tensors too; ``get_device()`` returns -1 for
    # them, which ``torch.randperm`` rejects as a device.
    cur_dev = pt_fea[0].device

    # Prefix every grid index with its sample number, then concatenate.
    cat_pt_ind = torch.cat(
        [F.pad(ind, (1, 0), 'constant', value=i_batch) for i_batch, ind in enumerate(xy_ind)],
        dim=0)
    cat_pt_fea = torch.cat(pt_fea, dim=0)
    pt_num = cat_pt_ind.shape[0]

    # Shuffle the points.
    shuffled_ind = torch.randperm(pt_num, device=cur_dev)
    cat_pt_fea = cat_pt_fea[shuffled_ind, :]
    cat_pt_ind = cat_pt_ind[shuffled_ind, :]

    # Unique grid cells and, for every point, the cell it falls in.
    unq, unq_inv = torch.unique(cat_pt_ind, return_inverse=True, dim=0)
    unq = unq.type(torch.int64)

    processed_cat_pt_fea = self.PPmodel(cat_pt_fea)
    pooled_data = torch_scatter.scatter_max(processed_cat_pt_fea, unq_inv, dim=0)[0]
    if self.fea_compre:
        processed_pooled_data = self.fea_compression(pooled_data)
    else:
        processed_pooled_data = pooled_data
    return unq, processed_pooled_data
def softmax(src, index, num_nodes=None):
    r"""Sparse softmax of all values from the :attr:`src` tensor at the
    indices specified in the :attr:`index` tensor along the first dimension.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    if num_nodes is None:
        # Use a plain int for ``dim_size``; an empty index means zero groups
        # instead of failing inside ``index.max()``.
        num_nodes = int(index.max()) + 1 if index.numel() > 0 else 0
    # Subtract each group's maximum for numerical stability.
    out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index]
    out = out.exp()
    out = out / (scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16)
    return out
def forward(self, x, edge_index, edge_weight=None, batch=None):
    """Attention-based cluster pooling step.

    Steps:
        1. Add the remaining self loops.
        2. Build per-node cluster features with ``gnn_intra_cluster``.
        3. Form a master query per cluster from the max of its members' features.
        4. Attention-weight the members (grouped softmax plus dropout) and sum them.
        5. Score the clusters with ``gnn_score``, keep those selected by ``topk``
           (using ``self.ratio``), and rebuild edges with ``graph_connectivity``.

    Returns:
        ``(x, edge_index, edge_weight, batch, perm)`` for the kept clusters.
    """
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))
    x = x.unsqueeze(-1) if x.dim() == 1 else x  # N x F

    # Add self loops.
    fill_value = 1
    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)  # nodes per graph
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index=edge_index, edge_weight=edge_weight,
        fill_value=fill_value, num_nodes=num_nodes.sum())

    N = x.size(0)  # total number of nodes in the batch

    x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index, edge_weight=edge_weight)
    x_pool_j = x_pool[edge_index[1]]  # E x F
    x_j = x[edge_index[1]]  # E x F

    # --- Master query formation ---
    X_q, _ = scatter_max(x_pool_j, edge_index[0], dim=0)  # N x F
    M_q = self.lin_q(X_q)  # N x F
    # Index with the tensor itself; ``.tolist()`` copies the indices to the
    # host and indexes through a Python list on every call.
    M_q = M_q[edge_index[0]]  # E x F

    score = self.gat_att(torch.cat((M_q, x_pool_j), dim=-1))
    score = F.leaky_relu(score, self.negative_slope)
    score = softmax(score, edge_index[0], num_nodes=num_nodes.sum())

    # Sample attention coefficients stochastically.
    score = F.dropout(score, p=self.dropout_att, training=self.training)
    v_j = x_j * score.view(-1, 1)  # E x F

    # --- Aggregation ---
    out = scatter_add(v_j, edge_index[0], dim=0)  # N x F

    # --- Cluster selection ---
    fitness = torch.sigmoid(self.gnn_score(x=out, edge_index=edge_index)).view(-1)
    perm = topk(x=fitness, ratio=self.ratio, batch=batch)
    x = out[perm] * fitness[perm].view(-1, 1)

    # --- Maintaining graph connectivity ---
    batch = batch[perm]
    edge_index, edge_weight = graph_connectivity(
        device=x.device, perm=perm, edge_index=edge_index, edge_weight=edge_weight,
        score=score, ratio=self.ratio, batch=batch, N=N)
    return x, edge_index, edge_weight, batch, perm
def forward(self, x, pos, batch):
    """Set-abstraction step: sample cluster centres, max-pool their neighbours.

    Returns:
        ``(out, sub_pos, sub_batch)``: the max-pooled feature of each sampled
        cluster, the cluster positions, and their batch vector (``None`` when
        ``batch`` is ``None``).
    """
    # Farthest-point sampling picks the cluster centres.
    id_clusters = fps(pos, ratio=self.ratio, batch=batch)
    sub_batch = None if batch is None else batch[id_clusters]

    # k nearest input points of every centre; a centre is itself an input
    # point, so it can appear among its own neighbours.
    assignment = knn(pos, pos[id_clusters], k=self.k, batch_x=batch, batch_y=sub_batch)
    cluster_of, point_of = assignment[0], assignment[1]

    features = self.mlp(x)
    pooled = scatter_max(features[point_of], cluster_of, dim_size=id_clusters.size(0), dim=0)[0]
    return pooled, pos[id_clusters], sub_batch
def forward(self, graph_data, correct_candidate_idxs):
    """Return the mean negative log-likelihood of the correct candidates.

    Each candidate node is scored from [its representation, its slot's
    representation], and scores are log-softmaxed within each slot. As a
    side effect, the number of slots whose highest-scoring candidate index
    equals ``correct_candidate_idxs`` is added to ``self.__sum_acc``, and the
    slot count is added to ``self.__num_samples``.
    """
    gnn_output: GnnOutput = self._gnn(**graph_data)
    node_reprs = gnn_output.output_node_representations
    # Code assumes that there is one slot per-graph, which is true for the original data
    candidate_node_representations = node_reprs[gnn_output.node_idx_references["candidate_nodes"]]  # [num_candidate_nodes, H_out]
    candidate_nodes_slot_idx = gnn_output.node_graph_idx_reference["candidate_nodes"]  # [num_candidate_nodes]
    slot_representations = node_reprs[gnn_output.node_idx_references["slot_node_idx"]]  # [num_slot_nodes, H_out]

    paired = torch.cat(
        (candidate_node_representations, slot_representations[candidate_nodes_slot_idx]),
        dim=-1)
    candidate_scores = self.__candidate_scores(paired).squeeze(-1)
    candidate_nodes_logprobs = scatter_log_softmax(
        src=candidate_scores, index=candidate_nodes_slot_idx, eps=0)

    with torch.no_grad():
        best_candidate = scatter_max(candidate_scores, index=candidate_nodes_slot_idx)[1]
        self.__sum_acc += int((best_candidate == correct_candidate_idxs).sum())
        self.__num_samples += int(slot_representations.shape[0])

    return -candidate_nodes_logprobs[correct_candidate_idxs].mean()
def decode(self, y, ext_x, prev_states, prev_context, encoder_features, encoder_mask):
    """Run one LSTM decoding step.

    Args:
        y: Previous output token ids, ``[b]``.
        ext_x: Source token ids in the extended (vocabulary + OOV) space;
            used when ``config.use_pointer`` is set.
        prev_states, prev_context: LSTM state and attention context from
            the previous step.
        encoder_features, encoder_mask: Passed to ``self.attention``.

    Returns:
        ``(logit, states, context)``. With ``config.use_pointer``, the logits
        get one extra column per OOV id, and the attention energies,
        max-scattered onto ``ext_x``, are added. The UNK logit is always set
        to ``-INF``.
    """
    embedded = self.embedding(y.unsqueeze(1))
    lstm_inputs = self.reduce_layer(torch.cat([embedded, prev_context], 2))
    output, states = self.lstm(lstm_inputs, prev_states)

    context, energy = self.attention(output, encoder_features, encoder_mask)
    concat_input = torch.cat((output, context), 2).squeeze(1)
    logit = self.logit_layer(torch.tanh(self.concat_layer(concat_input)))  # [b, |V|]

    if config.use_pointer:
        num_oov = max(torch.max(ext_x - self.vocab_size + 1), 0)
        padding = torch.zeros((y.size(0), num_oov), device=config.device)
        extended_logit = torch.cat([logit, padding], dim=1)

        copy_scores = torch.zeros_like(extended_logit) - INF
        copy_scores, _ = scatter_max(energy, ext_x, out=copy_scores)
        # Entries no source token maps to get no copy bonus.
        copy_scores = copy_scores.masked_fill(copy_scores == -INF, 0)
        logit = extended_logit + copy_scores
        logit = logit.masked_fill(logit == -INF, 0)

    # Force the UNK probability to 0.
    logit[:, UNK_ID] = -INF
    return logit, states, context
def softmax_weights(self):
    """Softmax of the outgoing edge weights of each node.

    Each edge weight is shifted by the largest outgoing weight of its sender
    before exponentiating (softmax is shift-invariant), then divided by the
    sum over that sender's outgoing edges. A warning is logged if the result
    contains NaN.

    Returns:
        Graph: ``self.update(edges=<normalized weights>)``.
    """
    per_sender_max = scatter_max(
        src=self.edges,
        index=self.senders,
        dim=0,
        dim_size=self.n_node,
    )[0]
    exp_weights = (self.edges - per_sender_max[self.senders]).exp()
    per_sender_sum = scatter_add(src=exp_weights, index=self.senders, dim=0, dim_size=self.n_node)
    normalized_weights = exp_weights / per_sender_sum[self.senders]

    if any_nan(normalized_weights):
        logging.warning("NaN weight after normalization in graph `softmax_weights`")
    return self.update(edges=normalized_weights)
def max_edge_weight_per_node(self) -> torch.Tensor:
    """Returns weight and edge_idx of the max weight outgoing edge for each node

    Returns:
        torch.Tensor: [n_node, ] weights, [n_node, ] indices (as a tuple)

    NOTE(review): nodes without outgoing edges receive torch_scatter's fill
    values (presumably weight 0 and an out-of-range index) — confirm against
    the installed torch_scatter version.
    """
    # dim_size=n_node: without it, the result stops at the largest sender
    # index and does not have the promised [n_node] entries.
    return scatter_max(self.edges.squeeze(), self.senders, dim=0, dim_size=self.n_node)
def real_softmax(src, index, num_nodes=None):
    r"""Computes a sparsely evaluated softmax.

    Given a value tensor :attr:`src`, this function first groups the values
    along the first dimension based on the indices specified in
    :attr:`index`, and then proceeds to compute the softmax individually for
    each group. Asserts that neither the exponentials nor the result contain
    NaN/Inf.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    num_nodes = maybe_num_nodes(index, num_nodes)
    group_max = scatter_max(src, index, dim=0, dim_size=num_nodes)[0]
    exps = (src - group_max[index]).exp()
    assert not nan_or_inf(exps)
    denom = scatter_add(exps, index, dim=0, dim_size=num_nodes)[index]
    result = exps / (denom + 1e-16)
    assert not nan_or_inf(result)
    return result
def __call__(self, data, norm=True):
    """Append local degree statistics to ``data.x``.

    For every node this computes its degree and the min, max, mean and std
    of its neighbours' degrees, where neighbours are the ``col`` entries of
    edges whose ``row`` is the node. With ``norm``, degrees are first divided
    by the maximum degree.

    Returns:
        The same ``data`` object, with the five statistics concatenated to
        ``data.x`` (or stored as ``data.x`` when it is ``None``).
    """
    row, col = data.edge_index
    N = data.num_nodes

    deg = degree(row, N, dtype=torch.float)
    if norm and deg.numel() > 0:
        max_degree = deg.max()
        # An edgeless graph has max degree 0; dividing would yield NaN.
        if max_degree > 0:
            deg = deg / max_degree
    deg_col = deg[col]

    # Presumably nodes with no neighbours get large-magnitude fill values
    # from scatter_min / scatter_max (version-dependent); reset them to 0.
    min_deg, _ = scatter_min(deg_col, row, dim_size=N)
    min_deg[min_deg > 10000] = 0
    max_deg, _ = scatter_max(deg_col, row, dim_size=N)
    max_deg[max_deg < -10000] = 0
    mean_deg = scatter_mean(deg_col, row, dim_size=N)
    std_deg = scatter_std(deg_col, row, dim_size=N)

    x = torch.stack([deg, min_deg, max_deg, mean_deg, std_deg], dim=1)
    if data.x is not None:
        data.x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x
        data.x = torch.cat([data.x, x], dim=-1)
    else:
        data.x = x
    return data
def forward(self, x, edge_index):
    """Transform node features and pool them over each node's neighbourhood.

    Self loops are removed and then re-added, so every node contributes
    exactly once to its own neighbourhood. Features are transformed by
    ``self.act(x @ self.weight (+ self.bias))``, aggregated from ``col`` onto
    ``row`` with ``self.pool`` (``'mean'``, ``'max'`` or ``'add'``), and
    optionally L2-normalised.

    Raises:
        ValueError: if ``self.pool`` is not a supported mode.
    """
    if self.pool not in ('mean', 'max', 'add'):
        raise ValueError('pooling not defined!')

    edge_index, _ = remove_self_loops(edge_index)
    edge_index = add_self_loops(edge_index, num_nodes=x.size(0))
    x = x.unsqueeze(-1) if x.dim() == 1 else x
    row, col = edge_index

    out = torch.matmul(x, self.weight)
    if self.bias is not None:
        out = out + self.bias
    out = self.act(out)

    if self.pool == 'mean':
        out = scatter_mean(out[col], row, dim=0, dim_size=out.size(0))
    elif self.pool == 'max':
        out, _ = scatter_max(out[col], row, dim=0, dim_size=out.size(0))
    else:
        # 'add' pools the transformed features, consistent with the other modes.
        out = scatter_add(out[col], row, dim=0, dim_size=out.size(0))

    if self.normalize:
        out = F.normalize(out, p=2, dim=-1)
    return out
def forward(self, data):
    """Compute one global feature vector per graph in the batch.

    Three stacked ``gcu`` units process ``data.pos`` using the topology and
    geometry edge sets. Their outputs are concatenated, passed through
    ``mlp_glb``, and max-pooled per graph according to ``data.batch``.
    """
    edge_sets = (data.tpl_edge_index, data.geo_edge_index)
    x_1 = self.gcu_1(data.pos, *edge_sets)
    x_2 = self.gcu_2(x_1, *edge_sets)
    x_3 = self.gcu_3(x_2, *edge_sets)
    per_node = self.mlp_glb(torch.cat([x_1, x_2, x_3], dim=1))
    return scatter_max(per_node, data.batch, dim=0)[0]
def __get_updated_memory__(self, n_id):
    """Return updated memory and last-update times for the nodes ``n_id``.

    Messages in both directions (src -> dst and dst -> src) are computed,
    aggregated per node, and fed to ``self.gru`` with the stored memory of
    ``n_id``. This method writes only ``self.__assoc__``, which maps each node
    id to its position in ``n_id``; it does not assign to ``self.memory`` or
    ``self.last_update``.

    Returns:
        ``(memory, last_update)`` for the nodes in ``n_id``.
    """
    n = n_id.size(0)
    self.__assoc__[n_id] = torch.arange(n, device=n_id.device)

    # Messages src -> dst, then dst -> src.
    msg_s, t_s, src_s, dst_s = self.__compute_msg__(n_id, self.msg_s_store, self.msg_s_module)
    msg_d, t_d, src_d, dst_d = self.__compute_msg__(n_id, self.msg_d_store, self.msg_d_module)

    idx = torch.cat([src_s, src_d], dim=0)
    msg = torch.cat([msg_s, msg_d], dim=0)
    t = torch.cat([t_s, t_d], dim=0)
    aggr = self.aggr_module(msg, self.__assoc__[idx], t, n)

    memory = self.gru(aggr, self.memory[n_id])

    # Latest message time per node id, looked up for the requested nodes.
    latest = scatter_max(t, idx, dim=0, dim_size=self.last_update.size(0))[0]
    return memory, latest[n_id]
def take_action_deterministic_batch_dqn(target_net, player, batch_instances):
    """Pick a greedy action per graph from the target network's values.

    Only nodes with ``J[:, 0] == 0`` are candidates. Player 1 (the attacker)
    takes the argmin of the action values and any other player takes the
    argmax, independently for each graph in the batch.

    Returns:
        List with one action index per graph. NOTE(review): the indices come
        from the scatter over the *filtered* candidate tensor, so they index
        the masked nodes rather than the original node order — confirm that
        callers expect this.
    """
    with torch.no_grad():
        batch = batch_instances.G_torch.batch
        candidates = batch_instances.J.eq(0)[:, 0]
        action_values = target_net(
            batch_instances.G_torch,
            batch_instances.n_nodes,
            batch_instances.Omegas,
            batch_instances.Phis,
            batch_instances.Lambdas,
            batch_instances.Omegas_norm,
            batch_instances.Phis_norm,
            batch_instances.Lambdas_norm,
            batch_instances.J,
        )[candidates]
        batch = batch[candidates]
        reduce = scatter_min if player == 1 else scatter_max
        _, actions = reduce(action_values, batch, dim=0)
    return actions.view(-1).tolist()