def forward(self, x, e, o, edge_index):
    start, end = edge_index
    # Aggregate edge-weighted incoming/outgoing features
    mi = scatter_add(e[:, None] * x[start], end, dim=0, dim_size=x.shape[0])
    mo = scatter_add(e[:, None] * x[end], start, dim=0, dim_size=x.shape[0])
    global_i = scatter_add(torch.ger(e, o), end, dim=0, dim_size=x.shape[0])
    global_o = scatter_add(torch.ger(e, o), start, dim=0, dim_size=x.shape[0])
    # print(mi.shape, mo.shape, global_i.shape, global_o.shape, x.shape,
    #       (torch.cat([mi, mo, global_i, global_o, x], dim=1)).shape)
    node_inputs = torch.cat([mi, mo, global_i, global_o, x], dim=1)
    return self.network(node_inputs)
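# --- Hedged usage sketch (not from the original source) ---
# Minimal, self-contained illustration of how scatter_add sums edge-weighted
# messages into destination nodes, as done above. Shapes and values are made
# up for illustration; assumes torch and torch_scatter are installed.
import torch
from torch_scatter import scatter_add

x = torch.arange(6, dtype=torch.float).view(3, 2)   # 3 nodes, 2 features
edge_index = torch.tensor([[0, 1, 2],                # senders
                           [1, 2, 1]])               # receivers
e = torch.tensor([0.5, 1.0, 2.0])                    # per-edge weights
start, end = edge_index
mi = scatter_add(e[:, None] * x[start], end, dim=0, dim_size=x.shape[0])
# mi[1] == 0.5 * x[0] + 2.0 * x[2]; nodes with no incoming edges stay zero.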
def forward(self, x, batch, size=None):
    """"""
    x = x.unsqueeze(-1) if x.dim() == 1 else x
    size = batch[-1].item() + 1 if size is None else size

    gate = self.gate_nn(x).view(-1, 1)
    x = self.nn(x) if self.nn is not None else x
    assert gate.dim() == x.dim() and gate.size(0) == x.size(0)

    gate = softmax(gate, batch, num_nodes=size)
    out = scatter_add(gate * x, batch, dim=0, dim_size=size)

    return out
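# --- Hedged usage sketch (not from the original source) ---
# Toy illustration of the gated readout above: a per-node gate is normalized
# per graph with a segment softmax, then used to weight a scatter_add.
# Assumes torch_scatter and torch_geometric are installed; the softmax call
# mirrors the signature used in the forward() above.
import torch
from torch_scatter import scatter_add
from torch_geometric.utils import softmax

x = torch.randn(5, 4)                       # 5 nodes, 4 features
batch = torch.tensor([0, 0, 0, 1, 1])       # node -> graph assignment
gate = torch.randn(5, 1)                    # unnormalized per-node scores
gate = softmax(gate, batch, num_nodes=2)    # sums to 1 within each graph
readout = scatter_add(gate * x, batch, dim=0, dim_size=2)  # shape [2, 4]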
def forward(self, x, edge_index):
    # Encode the graph features into the hidden space
    input_x = x
    x = self.node_encoder(x)
    start, end = edge_index

    # Loop over iterations of edge and node networks
    for i in range(self.hparams["n_graph_iters"]):
        # Previous hidden state
        # x0 = x

        # Compute new edge score
        edge_inputs = torch.cat([x[start], x[end]], dim=1)
        e = checkpoint(self.edge_network, edge_inputs)
        e = torch.sigmoid(e)

        # Sum weighted node features coming into each node
        # weighted_messages_in = scatter_add(e * x[start], end, dim=0, dim_size=x.shape[0])
        # weighted_messages_out = scatter_add(e * x[end], start, dim=0, dim_size=x.shape[0])
        weighted_messages = scatter_add(
            e * x[start], end, dim=0, dim_size=x.shape[0]
        ) + scatter_add(e * x[end], start, dim=0, dim_size=x.shape[0])

        # Compute new node features
        # node_inputs = torch.cat([x, weighted_messages_in, weighted_messages_out], dim=1)
        node_inputs = torch.cat([x, weighted_messages], dim=1)
        # node_inputs = weighted_messages + x
        x = checkpoint(self.node_network, node_inputs)

        # Residual connection
        # x = x + x0

    # Compute final edge scores; use original edge directions only
    clf_inputs = torch.cat([x[start], x[end]], dim=1)
    return checkpoint(self.edge_network, clf_inputs).squeeze(-1)
def colcount(self) -> torch.Tensor:
    colcount = self._colcount
    if colcount is not None:
        return colcount

    colptr = self._colptr
    if colptr is not None:
        colcount = colptr[1:] - colptr[:-1]
    else:
        colcount = scatter_add(torch.ones_like(self._col), self._col,
                               dim_size=self._sparse_sizes[1])
    self._colcount = colcount
    return colcount
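# --- Hedged usage sketch (not from the original source) ---
# Per-column non-zero counts via scatter_add of ones, the fallback path used
# in colcount() above when no colptr is cached. Values are illustrative only.
import torch
from torch_scatter import scatter_add

col = torch.tensor([0, 2, 2, 3])             # column index of each non-zero
num_cols = 5
colcount = scatter_add(torch.ones_like(col), col, dim_size=num_cols)
# tensor([1, 0, 2, 1, 0])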
def forward(self, x, edge_index, pseudo):
    """"""
    # See https://github.com/shchur/gnn-benchmark for the reference
    # TensorFlow implementation.
    x = x.unsqueeze(-1) if x.dim() == 1 else x
    pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo

    row, col = edge_index
    F, (E, D) = x.size(1), pseudo.size()

    gaussian = -0.5 * (pseudo.view(E, 1, D) - self.mu.view(1, F, D))**2
    gaussian = torch.exp(gaussian / (1e-14 + self.sigma.view(1, F, D)**2))
    gaussian = gaussian.prod(dim=-1)

    # Normalize gaussians in edge dimension.
    gaussian_mean = scatter_add(gaussian, row, dim=0, dim_size=x.size(0))
    gaussian = gaussian / (1e-14 + gaussian_mean[row]).view(E, F)

    out = scatter_add(x[col] * gaussian, row, dim=0, dim_size=x.size(0))
    out = self.lin(out)

    return out
def i_and_u(pred, target, num_classes, batch=None):
    r"""Computes intersection and union of predictions.

    Args:
        pred (LongTensor): The predictions.
        target (LongTensor): The targets.
        num_classes (int): The number of classes.
        batch (LongTensor): The assignment vector which maps each pred-target
            pair to an example.

    :rtype: (:class:`LongTensor`, :class:`LongTensor`)
    """
    pred, target = F.one_hot(pred, num_classes), F.one_hot(target, num_classes)

    if batch is None:
        i = (pred & target).sum(dim=0)
        u = (pred | target).sum(dim=0)
    else:
        i = scatter_add(pred & target, batch, dim=0)
        u = scatter_add(pred | target, batch, dim=0)

    return i, u
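# --- Hedged usage sketch (not from the original source) ---
# Tiny example of the per-example intersection/union computed by i_and_u above,
# using one-hot encodings and scatter_add over the batch vector.
import torch
import torch.nn.functional as F
from torch_scatter import scatter_add

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, 2])
batch = torch.tensor([0, 0, 1, 1])            # two examples
p, t = F.one_hot(pred, 3), F.one_hot(target, 3)
i = scatter_add(p & t, batch, dim=0)          # per-example, per-class intersection
u = scatter_add(p | t, batch, dim=0)          # per-example, per-class union
# i[0] == [1, 1, 0]; u[1] == [0, 1, 2]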
def __init__(self, batch: GraphBatch):
    self._batch = batch
    self._pooling_functions = {
        'mean': lambda src, idx: torch_scatter.scatter_mean(
            src, idx, dim=0, dim_size=batch.num_graphs),
        'sum': lambda src, idx: torch_scatter.scatter_add(
            src, idx, dim=0, dim_size=batch.num_graphs),
        'max': lambda src, idx: torch_scatter.scatter_max(
            src, idx, dim=0, dim_size=batch.num_graphs)[0],
    }
def forward(self, edges, vertices, target_idx):
    x = vertices
    x = torch.cat([x] + [torch_scatter.scatter_add(
        x[edges[:, 1]], edges[:, 0], dim=0, dim_size=vertices.size(0))], dim=1)
    x = self.l1(x)
    identity = x

    x = F.relu(self.l2(torch.cat([x] + [torch_scatter.scatter_add(
        x[edges[:, 1]], edges[:, 0], dim=0, dim_size=vertices.size(0))], dim=1)))
    x = x / (torch.norm(x, p=2, dim=1).unsqueeze(0).t() + 0.000001)
    x += identity  # residual connection

    x = self.dropout(F.relu(self.l3(torch.cat([x] + [torch_scatter.scatter_add(
        x[edges[:, 1]], edges[:, 0], dim=0, dim_size=vertices.size(0))], dim=1))))
    x = x / (torch.norm(x, p=2, dim=1).unsqueeze(0).t() + 0.000001)

    x_target = x[target_idx]
    x = torch.squeeze(x, dim=1)
    x_target = self.l4(x_target)
    x_target = torch.unsqueeze(x_target, dim=0)
    return x_target
def forward(self, x, edge_index, edge_attr, batch_mask):
    x = x @ self.weight
    x = self.norm(x, batch_mask)
    alpha, alpha_index = self.attention(x, edge_index, edge_attr)
    row, col = alpha_index
    num_nodes = x.size(0)

    deg = scatter_add(alpha.abs(), row, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    norm = deg_inv_sqrt[row] * alpha * deg_inv_sqrt[col]

    out = self.my_cast(norm, x[col])
    out = scatter_add(out, row, dim=0, dim_size=x.size(0))

    if self.bias is not None:
        out = out + self.bias
    return out, alpha, alpha_index
def __matmul__(self, node_signal: torch.Tensor) -> torch.Tensor:
    """product = input * weight = node_signal * W"""
    assert node_signal.shape[0] == self.n_node
    assert self.edges is not None and self.edges.squeeze().dim() == 1

    senders_features = node_signal[self.senders]
    broadcast_edges = self.edges.view(-1, *([1] * (node_signal.dim() - 1)))
    weighted_senders = senders_features * broadcast_edges
    node_results = scatter_add(src=weighted_senders, index=self.receivers,
                               dim=0, dim_size=self.n_node)
    return node_results
def forward(self, x, hyperedge_index, hyperedge_weight=None):
    r"""
    Args:
        x (Tensor): Node feature matrix :math:`\mathbf{X}`
        hyperedge_index (LongTensor): Hyperedge indices from :math:`\mathbf{H}`.
        hyperedge_weight (Tensor, optional): Sparse hyperedge weights from
            :math:`\mathbf{W}`. (default: :obj:`None`)
    """
    x = torch.matmul(x, self.weight)
    alpha = None

    if self.use_attention:
        x = x.view(-1, self.heads, self.out_channels)
        x_i, x_j = x[hyperedge_index[0]], x[hyperedge_index[1]]
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, hyperedge_index[0], num_nodes=x.size(0))
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

    if hyperedge_weight is None:
        D = degree(hyperedge_index[0], x.size(0), x.dtype)
    else:
        D = scatter_add(hyperedge_weight[hyperedge_index[1]],
                        hyperedge_index[0], dim=0, dim_size=x.size(0))
    D = 1.0 / D
    D[D == float("inf")] = 0

    if hyperedge_index.numel() == 0:
        num_edges = 0
    else:
        num_edges = hyperedge_index[1].max().item() + 1
    B = 1.0 / degree(hyperedge_index[1], num_edges, x.dtype)
    B[B == float("inf")] = 0
    if hyperedge_weight is not None:
        B = B * hyperedge_weight

    self.flow = 'source_to_target'
    out = self.propagate(hyperedge_index, x=x, norm=B, alpha=alpha)
    self.flow = 'target_to_source'
    out = self.propagate(hyperedge_index, x=out, norm=D, alpha=alpha)

    if self.concat is True:
        out = out.view(-1, self.heads * self.out_channels)
    else:
        out = out.mean(dim=1)

    if self.bias is not None:
        out = out + self.bias

    return out
def group_data(data, cluster=None, unique_pos_indices=None, mode="last", skip_keys=[]):
    """Group data based on indices in cluster.

    The option ``mode`` controls how data gets aggregated within each cluster.

    Parameters
    ----------
    data : Data
        [description]
    cluster : torch.Tensor
        Tensor of the same size as the number of points in data. Each element
        is the cluster index of that point.
    unique_pos_indices : torch.Tensor
        Tensor containing one index per cluster; this index will be used to
        select features and labels.
    mode : str
        Option to select how the features and labels for each voxel are
        computed. Can be ``last`` or ``mean``. ``last`` selects the last point
        falling in a voxel as the representative, ``mean`` takes the average.
    skip_keys : list
        Keys of attributes to skip in the grouping.
    """
    assert mode in ["mean", "last"]
    if mode == "mean" and cluster is None:
        raise ValueError("In mean mode the cluster argument needs to be specified")
    if mode == "last" and unique_pos_indices is None:
        raise ValueError("In last mode the unique_pos_indices argument needs to be specified")

    num_nodes = data.num_nodes
    for key, item in data:
        if bool(re.search("edge", key)):
            raise ValueError("Edges not supported. Wrong data type.")
        if key in skip_keys:
            continue

        if torch.is_tensor(item) and item.size(0) == num_nodes:
            if mode == "last" or key == "batch" or key == SaveOriginalPosId.KEY:
                data[key] = item[unique_pos_indices]
            elif mode == "mean":
                if key == "y":
                    item_min = item.min()
                    item = F.one_hot(item - item_min)
                    item = scatter_add(item, cluster, dim=0)
                    data[key] = item.argmax(dim=-1) + item_min
                else:
                    data[key] = scatter_mean(item, cluster, dim=0)
    return data
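# --- Hedged usage sketch (not from the original source) ---
# The "mean" path above pools labels with a majority vote: one-hot the labels,
# scatter_add the votes per cluster, then take the argmax. Small example with
# illustrative values:
import torch
import torch.nn.functional as F
from torch_scatter import scatter_add

y = torch.tensor([2, 2, 3, 5, 5, 5])
cluster = torch.tensor([0, 0, 0, 1, 1, 1])
y_min = y.min()
votes = scatter_add(F.one_hot(y - y_min), cluster, dim=0)
pooled = votes.argmax(dim=-1) + y_min        # tensor([2, 5])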
def message(self, x_i, x_j, edge_index, num_nodes):
    alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
    alpha = F.leaky_relu(alpha, self.negative_slope)
    alpha = softmax(alpha, edge_index[0], None, num_nodes)

    if self.mod == "additive":
        ones = torch.ones_like(alpha)
        h = x_j * ones.view(-1, self.heads, 1)
        h = torch.mul(self.w, h)
        return x_j * alpha.view(-1, self.heads, 1) + h
    elif self.mod == "scaled":
        ones = alpha.new_ones(edge_index[0].size())
        degree = scatter_add(ones, edge_index[0],
                             dim_size=num_nodes)[edge_index[0]].unsqueeze(-1)
        degree = torch.matmul(degree, self.l1) + self.b1
        degree = self.activation(degree)
        degree = torch.matmul(degree, self.l2) + self.b2
        degree = degree.unsqueeze(-2)
        return torch.mul(x_j * alpha.view(-1, self.heads, 1), degree)
    elif self.mod == "f-additive":
        alpha = torch.where(alpha > 0, alpha + 1, alpha)
    elif self.mod == "f-scaled":
        ones = alpha.new_ones(edge_index[0].size())
        degree = scatter_add(ones, edge_index[0],
                             dim_size=num_nodes)[edge_index[0]].unsqueeze(-1)
        alpha = alpha * degree
    else:
        alpha = alpha  # origin

    return x_j * alpha.view(-1, self.heads, 1)
def norm(self, edge_index, num_nodes, edge_weight, improved=False, dtype=None):
    with torch.no_grad():
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)
        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        row, col = edge_index
        deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
        # Note: this uses the plain inverse degree, not the inverse square root.
        deg_inv = deg.pow(-1)
        deg_inv[deg_inv == float('inf')] = 0
        return deg_inv[row], deg_inv[col]
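# --- Hedged usage sketch (not from the original source) ---
# Illustrates the degree computation above: per-node weighted degree via
# scatter_add over the target index, then inverse with inf -> 0 for isolated
# nodes. Values are illustrative only.
import torch
from torch_scatter import scatter_add

edge_index = torch.tensor([[0, 1, 1], [1, 2, 2]])
edge_weight = torch.tensor([1.0, 0.5, 0.5])
num_nodes = 3
row, col = edge_index
deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)  # [0., 1., 1.]
deg_inv = deg.pow(-1)
deg_inv[deg_inv == float('inf')] = 0                            # node 0 has no in-edges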
def forward(self,
            x: torch.Tensor,         # num_records * 3
            cond: torch.Tensor,      # ragged
            seg_ids: torch.Tensor):  # size(num_records)
    # x = self.bn(x)
    x = x[seg_ids, ...]
    x = torch.cat((x, cond), dim=-1)
    x = F.relu(self.bn1(self.fc1(x)))
    x = scatter_add(x, seg_ids, dim=0)
    x = self.fc2(F.relu(self.bn2(x)))
    mu, var = torch.split(x, x.size(-1) // 2, -1)
    var = F.softplus(var) / math.log(2)
    return mu, var
def forward(self, x, edge_index):
    x_in = x
    edge_index, _ = add_remaining_self_loops(edge_index)

    if self.norm == 'dropedge':
        edge_index, _ = dropout_adj(edge_index, force_undirected=True,
                                    training=self.training)

    row, col = edge_index
    deg = degree(row)
    deg_inv_sqrt = deg.pow(-0.5)
    norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

    # x = self.linear(x)

    if self.norm == 'neighbornorm':
        x_j = self.normlayer(x, edge_index)
    else:
        x_j = x[col]

    x_j = norm.view(-1, 1) * x_j
    out = scatter_add(src=x_j, index=row, dim=0, dim_size=x.size(0))
    out = self.linear(out)

    if self.activation:
        out = F.relu(out)

    if self.norm in ('batchnorm', 'layernorm', 'pairnorm', 'nodenorm'):
        out = self.normlayer(out)

    if self.residual:
        out = x_in + out

    if self.dropout:
        out = F.dropout(out, p=0.5, training=self.training)

    return out
def calc_log_prob(init_prob, traj_prob, rep_init, rep_rows, final_samples):
    if rep_rows.shape[0]:
        zeros = torch.zeros(rep_init.shape[0], 1).to(init_prob.device)
        nonstop = sampler.pred_stop(rep_init, zeros)[0]
        nonstop = scatter_add(nonstop, rep_rows, dim=0,
                              dim_size=final_samples.shape[0])
    else:
        nonstop = 0
    ones = torch.ones(final_samples.shape[0], 1).to(init_prob.device)
    last_stop = sampler.pred_stop(final_samples, ones)[0]
    return init_prob + traj_prob + nonstop + last_stop
def forward(self, x, edge_index, edge_weight=None):
    """"""
    # edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    # print(x.size(), edge_index.size())
    row, col = edge_index
    batch, num_nodes, num_edges, K = (x.size(0), x.size(1), row.size(0),
                                      self.weight.size(0))

    edge_weight = edge_weight.view(-1)
    assert edge_weight.size(0) == edge_index.size(1)

    # Degree matrix
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)

    # Compute normalized and rescaled Laplacian.
    deg = deg.pow(-0.5)
    deg[torch.isinf(deg)] = 0
    lap = -deg[row] * edge_weight * deg[col]

    # Rescale the Laplacian eigenvalues to [-1, 1]: 2L / lmax - I, with lmax = 1.0
    fill_value = -0.05  # -0.5
    edge_index, lap = add_self_loops(edge_index, lap, fill_value, num_nodes)
    lap *= 2

    # Perform filter operation recurrently.
    Tx_0 = x
    out = torch.matmul(Tx_0, self.weight[0])

    if K > 1:
        Tx_1 = sparse_dense_mat_mul(
            edge_index, lap, num_nodes,
            x.permute(1, 2, 0).contiguous().view((num_nodes, -1))
        ).view((num_nodes, -1, batch)).permute(2, 0, 1)
        out = out + torch.matmul(Tx_1, self.weight[1])

    for k in range(2, K):
        # Chebyshev recurrence: T_k = 2 * L * T_{k-1} - T_{k-2}
        Tx_2 = 2 * sparse_dense_mat_mul(
            edge_index, lap, num_nodes,
            Tx_1.permute(1, 2, 0).contiguous().view((num_nodes, -1))
        ).view((num_nodes, -1, batch)).permute(2, 0, 1) - Tx_0
        out = out + torch.matmul(Tx_2, self.weight[k])
        Tx_0, Tx_1 = Tx_1, Tx_2

    if self.bias is not None:
        out = out + self.bias
    return out
def forward(self, data):
    fingerprint = torch.zeros((data.batch.shape[0], self.fp_size),
                              dtype=torch.float)
    out = data.x
    print(type(data.edge_index))
    for idx, loop in enumerate(self.loops):
        updated_atom_features, updated_fingerprint = loop(out, data.edge_index)
        out = updated_atom_features
        fingerprint += updated_fingerprint
    return scatter_add(fingerprint, data.batch, dim=0)
def forward(self, x, edge_index, edge_weight=None, size=None):
    """"""
    num_nodes = x.shape[0]
    h = torch.matmul(x, self.weight)

    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=x.dtype,
                                 device=edge_index.device)
    edge_index, edge_weight = remove_self_loops(edge_index=edge_index,
                                                edge_attr=edge_weight)

    deg = scatter_add(edge_weight, edge_index[0], dim=0,
                      dim_size=num_nodes)  # + 1e-10

    h_j = edge_weight.view(-1, 1) * h[edge_index[1]]
    aggr_out = scatter_add(h_j, edge_index[0], dim=0, dim_size=num_nodes)
    out = (deg.view(-1, 1) * self.lin1(x) + aggr_out) + self.lin2(x)

    edge_index, edge_weight = add_self_loops(edge_index=edge_index,
                                             edge_weight=edge_weight,
                                             num_nodes=num_nodes)
    return out
def forward(self, data):
    data.x = F.elu(self.conv1(data.x, data.edge_index))
    data.x = F.elu(self.conv2(data.x, data.edge_index))
    data.x = F.elu(self.conv3(data.x, data.edge_index))
    x_1 = scatter_add(data.x, data.batch, dim=0)
    x = x_1

    if args.no_train:
        x = x.detach()

    x = F.elu(self.fc1(x))
    x = F.elu(self.fc2(x))
    x = self.fc3(x)
    return F.log_softmax(x, dim=1)
def __init__(self, graph: _BaseGraph):
    self._graph = graph
    # TODO move these to the class definition or somewhere else
    self._pooling_functions = {
        'mean': lambda src, idx: torch_scatter.scatter_mean(
            src, idx, dim=0, dim_size=graph.num_nodes),
        'sum': lambda src, idx: torch_scatter.scatter_add(
            src, idx, dim=0, dim_size=graph.num_nodes),
        'max': lambda src, idx: torch_scatter.scatter_max(
            src, idx, dim=0, dim_size=graph.num_nodes)[0],
    }
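# --- Hedged usage sketch (not from the original source) ---
# How a pooling table like the ones above might be used: pick a reduction by
# name and apply it to per-node features with a node -> graph index. The
# `pool` dict and tensors below are hypothetical stand-ins, not part of the
# original classes.
import torch
import torch_scatter

pool = {
    'mean': lambda src, idx, n: torch_scatter.scatter_mean(src, idx, dim=0, dim_size=n),
    'sum': lambda src, idx, n: torch_scatter.scatter_add(src, idx, dim=0, dim_size=n),
    'max': lambda src, idx, n: torch_scatter.scatter_max(src, idx, dim=0, dim_size=n)[0],
}
feats = torch.randn(6, 8)                     # 6 nodes, 8 features
graph_ids = torch.tensor([0, 0, 0, 1, 1, 1])  # node -> graph assignment
pooled = pool['max'](feats, graph_ids, 2)     # shape [2, 8]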
def forward(self, graphs: tg.GraphBatch):
    nodes = F.relu(self.g_n(graphs.node_features))
    globals = self.h_n(torch_scatter.scatter_add(
        nodes,
        segment_lengths_to_ids(graphs.num_nodes_by_graph),
        dim=0,
        dim_size=graphs.num_graphs))
    return graphs.evolve(
        num_edges=0,
        edge_features=None,
        node_features=None,
        global_features=globals,
        senders=None,
        receivers=None)
def forward(self, x, edge_index, edge_attr, u, history_vector, batch):
    gate = self.node_mlp_1(x)
    assert gate.dim() == x.dim() and gate.size(0) == x.size(0)
    # gate = torch.bmm(x.unsqueeze(1), self.ques_nn(u)[batch].unsqueeze(2)).squeeze(-1)
    # assert gate.dim() == x.dim() and gate.size(0) == x.size(0)
    gate = torch_geometric.utils.softmax(gate, batch, num_nodes=None)
    new_history_vector = scatter_add(gate * x, batch, dim=0, dim_size=None)
    return gate, new_history_vector
def forward(f, shapes, lmax, device):
    r_max = 1.1
    x = torch.ones(4, 1)
    batch = Batch.from_data_list([
        DataNeighbors(x, shape, r_max, self_interaction=False)
        for shape in shapes
    ])
    batch = batch.to(device)
    sh = rsh.spherical_harmonics_xyz(list(range(lmax + 1)), batch.edge_attr,
                                     'component')
    out = f(batch.x, batch.edge_index, batch.edge_attr, sh=sh, n_norm=3)
    out = scatter_add(out, batch.batch, dim=0)
    out = torch.tanh(out)
    return out
def _softmax(indexes, edge_values):
    """
    :param indexes: nodes of each edge
    :param edge_values: values of each edge
    :return: normalized values of edges considering nodes
    """
    edge_values = torch.exp(edge_values)
    row_sum = scatter_add(edge_values, indexes, dim=0)
    edge_softmax = edge_values / row_sum[indexes, :, :]
    return edge_softmax
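# --- Hedged note and sketch (not from the original source) ---
# The exp-then-normalize form above can overflow for large edge values. A safer
# variant subtracts the per-segment max before exponentiating (the "max trick"),
# sketched below; `segment_softmax_stable` is a hypothetical helper name and
# assumes torch_scatter is installed.
import torch
from torch_scatter import scatter_add, scatter_max

def segment_softmax_stable(values, index, num_segments):
    # subtract the per-segment max before exponentiating
    seg_max, _ = scatter_max(values, index, dim=0, dim_size=num_segments)
    values = values - seg_max[index]
    exp = torch.exp(values)
    denom = scatter_add(exp, index, dim=0, dim_size=num_segments)
    return exp / denom[index]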
def forward(self, x, edge_index, edge_attr,
            autoregressive_x=None, node_mask=None):
    '''
    :param x: tuple (s, V) of `torch.Tensor`
    :param edge_index: array of shape [2, n_edges]
    :param edge_attr: tuple (s, V) of `torch.Tensor`
    :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
            If not `None`, will be used as src node embeddings
            for forming messages where src >= dst. The current node
            embeddings `x` will still be the base of the update and the
            pointwise feedforward.
    :param node_mask: array of type `bool` to index into the first
            dim of node embeddings (s, V). If not `None`, only
            these nodes will be updated.
    '''
    if autoregressive_x is not None:
        src, dst = edge_index
        mask = src < dst
        edge_index_forward = edge_index[:, mask]
        edge_index_backward = edge_index[:, ~mask]
        edge_attr_forward = tuple_index(edge_attr, mask)
        edge_attr_backward = tuple_index(edge_attr, ~mask)

        dh = tuple_sum(
            self.conv(x, edge_index_forward, edge_attr_forward),
            self.conv(autoregressive_x, edge_index_backward, edge_attr_backward)
        )

        count = scatter_add(torch.ones_like(dst), dst,
                            dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)

        dh = dh[0] / count, dh[1] / count.unsqueeze(-1)
    else:
        dh = self.conv(x, edge_index, edge_attr)

    if node_mask is not None:
        x_ = x
        x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)

    x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))

    dh = self.ff_func(x)
    x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))

    if node_mask is not None:
        x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
        x = x_
    return x
def propagate_homo(self, graph, input, typ):
    [source, target] = graph
    signal = input[..., source, :]  # shape: [..., E, in_features]
    if self.bias is None:
        message = F.linear(signal, self.weight[typ])  # shape: [..., E, out_features]
    else:
        message = F.linear(signal, self.weight[typ],
                           self.bias[typ])  # shape: [..., E, out_features]
    output = torch_scatter.scatter_add(message, target, dim=-2,
                                       dim_size=input.size(-2))
    return output  # shape: [..., N, out_features]
def segment_softmax_with_bias(
        x: torch.Tensor,
        bias: torch.Tensor,
        seg_ids: torch.Tensor,
        eps: float = 1e-6) -> t.Tuple[torch.Tensor, torch.Tensor]:
    """Segment softmax with bias

    Args:
        x (torch.Tensor): Input tensor, with shape [N, F]
        bias (torch.Tensor): Input bias, with shape [num_seg, ]
        seg_ids (torch.Tensor): Vector of size N
        eps (float): A small value for numerical stability

    Returns:
        tuple[torch.Tensor]
    """
    # get shape information
    num_seg = bias.size(0)

    # The max trick
    # size: [N, F + 1]
    x_max: torch.Tensor = torch.cat(
        [x, bias.index_select(0, seg_ids).unsqueeze(-1)], dim=-1)
    # size: [N, ]
    x_max, _ = torch.max(x_max, dim=-1)
    # size: [num_seg, ]
    x_max, _ = torch_scatter.scatter_max(x_max, index=seg_ids, dim=0,
                                         dim_size=num_seg)
    x = x - x_max.index_select(0, seg_ids).unsqueeze(-1)
    bias = bias - x_max
    x_exp, bias_exp = torch.exp(x), torch.exp(bias)

    # shape: [num_seg, ]
    x_sum = torch_scatter.scatter_add(x_exp.sum(-1), dim=0, index=seg_ids,
                                      dim_size=num_seg)
    # shape: [num_seg, ]
    x_bias_sum = x_sum + bias_exp + eps
    # shape: [N, F]
    x_softmax = x_exp / x_bias_sum.index_select(0, seg_ids).unsqueeze(-1)
    # shape: [num_seg, ]
    bias_softmax = bias_exp / x_bias_sum

    return x_softmax, bias_softmax
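# --- Hedged usage sketch (not from the original source) ---
# Calling the function above on toy data: each segment's softmax runs over its
# rows of `x` plus one extra "bias" logit; assumes torch_scatter is installed
# and the tensors below are illustrative only.
import torch

x = torch.randn(4, 3)                      # 4 items, 3 logits each
bias = torch.zeros(2)                      # one bias logit per segment
seg_ids = torch.tensor([0, 0, 1, 1])
x_soft, bias_soft = segment_softmax_with_bias(x, bias, seg_ids)
# For each segment s: x_soft[seg_ids == s].sum() + bias_soft[s] is approximately 1.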
def node_model(x, edge_index, edge_attr, u=None, v_indices=None):
    # x: [N, F_x], where N is the number of nodes.
    # edge_index: [2, E] with max entry N - 1.
    # edge_attr: [E, F_e]
    if self.independent:
        return self.node_mlp(x)
    row, col = edge_index
    if self.e2v_agg == "sum":
        out = scatter_add(edge_attr, row, dim=0, dim_size=x.size(0))
    elif self.e2v_agg == "mean":
        out = scatter_mean(edge_attr, row, dim=0, dim_size=x.size(0))
    out = torch.cat([x, out, u[v_indices]], dim=1)
    return self.node_mlp(out)