def forward(self, x, e, o, edge_index):
     """Assemble per-node inputs from edge-weighted neighbours and run the network.

     For every node, four sums over its incident edges are concatenated with
     the node's own features ``x`` and fed to ``self.network``:

     * edge-weighted features of the nodes at the other end of incoming edges,
     * edge-weighted features of the nodes at the other end of outgoing edges,
     * the outer product ``torch.ger(e, o)`` summed over incoming edges,
     * the same outer product summed over outgoing edges.

     ``torch.ger`` only accepts 1-D vectors, so ``e`` and ``o`` are presumably
     1-D with ``e`` holding one weight per edge -- TODO confirm with caller.
     """
     src, dst = edge_index
     num_nodes = x.shape[0]
     edge_weight = e[:, None]

     def pool(values, index):
         # Sum rows of `values` into one row per node.
         return scatter_add(values, index, dim=0, dim_size=num_nodes)

     incoming = pool(edge_weight * x[src], dst)
     outgoing = pool(edge_weight * x[dst], src)

     # Both "global" aggregations reuse the same per-edge outer product.
     edge_outer = torch.ger(e, o)
     global_in = pool(edge_outer, dst)
     global_out = pool(edge_outer, src)

     stacked = torch.cat([incoming, outgoing, global_in, global_out, x], dim=1)
     return self.network(stacked)
Esempio n. 2
0
    def forward(self, x, batch, size=None):
        """Pool node features per graph using learned soft attention.

        A scalar gate is computed per node with ``self.gate_nn``, normalised
        with a softmax over the nodes sharing the same ``batch`` id, and used
        to weight the (optionally ``self.nn``-transformed) node features before
        summing them per ``batch`` id.

        When ``size`` is not given it is taken as ``batch[-1] + 1``.
        """
        if x.dim() == 1:
            x = x.unsqueeze(-1)
        if size is None:
            size = batch[-1].item() + 1

        scores = self.gate_nn(x).view(-1, 1)
        if self.nn is not None:
            x = self.nn(x)
        assert scores.dim() == x.dim() and scores.size(0) == x.size(0)

        weights = softmax(scores, batch, num_nodes=size)
        return scatter_add(weights * x, batch, dim=0, dim_size=size)
    def forward(self, x, edge_index):
        """Run the iterative edge/node message passing and score every edge.

        Node features are encoded once, then refined for
        ``self.hparams["n_graph_iters"]`` rounds. Each round scores every edge
        with ``self.edge_network`` (squashed by a sigmoid), sums the
        score-weighted neighbour features arriving over both edge directions,
        and feeds ``[features, messages]`` to ``self.node_network``.

        Returns the un-squashed output of ``self.edge_network`` for each edge
        with its trailing dimension squeezed away.
        """
        hidden = self.node_encoder(x)
        src, dst = edge_index

        for _ in range(self.hparams["n_graph_iters"]):
            num_nodes = hidden.shape[0]

            # Edge gate in (0, 1) computed from both endpoint features.
            pair_features = torch.cat([hidden[src], hidden[dst]], dim=1)
            gate = torch.sigmoid(checkpoint(self.edge_network, pair_features))

            # Messages flow along the edge and against it; both are summed.
            into_dst = scatter_add(gate * hidden[src], dst, dim=0,
                                   dim_size=num_nodes)
            into_src = scatter_add(gate * hidden[dst], src, dim=0,
                                   dim_size=num_nodes)
            messages = into_dst + into_src

            node_inputs = torch.cat([hidden, messages], dim=1)
            hidden = checkpoint(self.node_network, node_inputs)

        # Final edge scores use the original edge directions only.
        final_pairs = torch.cat([hidden[src], hidden[dst]], dim=1)
        return checkpoint(self.edge_network, final_pairs).squeeze(-1)
Esempio n. 4
0
    def colcount(self) -> torch.Tensor:
        """Return the number of stored entries per column, caching the result.

        The cached ``self._colcount`` is returned when present. Otherwise the
        counts are derived from consecutive differences of ``self._colptr``
        when it is available, or by counting occurrences in ``self._col``.
        The computed tensor is stored in ``self._colcount`` before returning.
        """
        cached = self._colcount
        if cached is not None:
            return cached

        ptr = self._colptr
        if ptr is None:
            counts = scatter_add(torch.ones_like(self._col), self._col,
                                 dim_size=self._sparse_sizes[1])
        else:
            counts = ptr[1:] - ptr[:-1]

        self._colcount = counts
        return counts
Esempio n. 5
0
    def forward(self, x, edge_index, pseudo):
        """Aggregate neighbour features weighted by Gaussian kernels on ``pseudo``.

        One Gaussian (parameters ``self.mu`` / ``self.sigma``) is evaluated per
        input channel on each edge's pseudo-coordinates, normalised by the sum
        over edges sharing the same ``edge_index[0]`` node, and used to weight
        the features gathered from ``edge_index[1]`` before they are summed per
        node and passed through ``self.lin``.
        """
        # See https://github.com/shchur/gnn-benchmark for the reference
        # TensorFlow implementation.
        if x.dim() == 1:
            x = x.unsqueeze(-1)
        if pseudo.dim() == 1:
            pseudo = pseudo.unsqueeze(-1)
        target, source = edge_index

        channels = x.size(1)
        num_edges, pseudo_dim = pseudo.size()
        num_nodes = x.size(0)

        offset = (pseudo.view(num_edges, 1, pseudo_dim) -
                  self.mu.view(1, channels, pseudo_dim))
        exponent = -0.5 * offset**2
        spread = 1e-14 + self.sigma.view(1, channels, pseudo_dim)**2
        kernel = torch.exp(exponent / spread).prod(dim=-1)

        # Normalize the kernel responses over the edges of each target node.
        kernel_total = scatter_add(kernel, target, dim=0, dim_size=num_nodes)
        kernel = kernel / (1e-14 + kernel_total[target]).view(
            num_edges, channels)

        pooled = scatter_add(x[source] * kernel, target, dim=0,
                             dim_size=num_nodes)
        return self.lin(pooled)
Esempio n. 6
0
def i_and_u(pred, target, num_classes, batch=None):
    r"""Count per-class intersections and unions of predictions and targets.

    Args:
        pred (LongTensor): The predictions.
        target (LongTensor): The targets.
        num_classes (int): The number of classes.
        batch (LongTensor): The assignment vector which maps each pred-target
            pair to an example. When given, counts are accumulated separately
            per example; otherwise they are summed over all pairs.

    :rtype: (:class:`LongTensor`, :class:`LongTensor`)
    """
    pred_hot = F.one_hot(pred, num_classes)
    target_hot = F.one_hot(target, num_classes)
    overlap = pred_hot & target_hot
    combined = pred_hot | target_hot

    if batch is not None:
        return (scatter_add(overlap, batch, dim=0),
                scatter_add(combined, batch, dim=0))
    return overlap.sum(dim=0), combined.sum(dim=0)
Esempio n. 7
0
 def __init__(self, batch: GraphBatch):
     """Store ``batch`` and register the pooling reducers keyed by name.

     Each reducer takes ``(src, idx)`` and scatters ``src`` along dim 0 into
     ``batch.num_graphs`` rows; ``batch.num_graphs`` is read when the reducer
     is called, not when it is created.
     """
     self._batch = batch

     def pool_mean(src, idx):
         return torch_scatter.scatter_mean(
             src, idx, dim=0, dim_size=batch.num_graphs)

     def pool_sum(src, idx):
         return torch_scatter.scatter_add(
             src, idx, dim=0, dim_size=batch.num_graphs)

     def pool_max(src, idx):
         # scatter_max returns a tuple; only its first element is kept.
         return torch_scatter.scatter_max(
             src, idx, dim=0, dim_size=batch.num_graphs)[0]

     self._pooling_functions = {
         'mean': pool_mean,
         'sum': pool_sum,
         'max': pool_max,
     }
    def forward(self, edges, vertices, target_idx):
        """Run three neighbour-aggregating layers and project the target rows.

        ``edges`` is indexed as ``edges[:, 0]`` (receiving node) and
        ``edges[:, 1]`` (sending node). Each layer concatenates the current
        node features with the sum of the senders' features per receiver.
        Layers two and three are followed by a row-wise L2 normalisation;
        layer two additionally has a residual connection and layer three
        dropout. The rows selected by ``target_idx`` are passed through
        ``self.l4`` and returned with a new leading dimension of size one.
        """
        receivers, senders = edges[:, 0], edges[:, 1]
        num_nodes = vertices.size(0)

        def with_neighbours(h):
            # [own features | sum of sender features per receiver]
            pooled = torch_scatter.scatter_add(
                h[senders], receivers, dim=0, dim_size=num_nodes)
            return torch.cat([h, pooled], dim=1)

        def l2_normalise(h):
            # Small constant guards against division by zero for all-zero rows.
            return h / (torch.norm(h, p=2, dim=1, keepdim=True) + 0.000001)

        x = self.l1(with_neighbours(vertices))

        identity = x
        x = l2_normalise(F.relu(self.l2(with_neighbours(x))))
        x = x + identity  # residual connection

        x = self.dropout(F.relu(self.l3(with_neighbours(x))))
        x = l2_normalise(x)

        x_target = self.l4(x[target_idx])
        return torch.unsqueeze(x_target, dim=0)
Esempio n. 9
0
    def forward(self, x, edge_index, edge_attr, batch_mask):
        """Propagate projected features along attention-weighted edges.

        ``x`` is multiplied by ``self.weight`` and normalised with
        ``self.norm``. ``self.attention`` returns edge coefficients ``alpha``
        and their index pair; the coefficients are rescaled by the inverse
        square roots of the summed ``|alpha|`` of both endpoints and used to
        weight neighbour features (via ``self.my_cast``) before summation.

        Returns ``(out, alpha, alpha_index)``.
        """
        x = self.norm(x @ self.weight, batch_mask)

        alpha, alpha_index = self.attention(x, edge_index, edge_attr)
        row, col = alpha_index
        num_nodes = x.size(0)

        strength = scatter_add(alpha.abs(), row, dim=0, dim_size=num_nodes)
        inv_sqrt = strength.pow(-0.5)
        # Nodes with zero total attention would otherwise yield inf.
        inv_sqrt[inv_sqrt == float('inf')] = 0
        coeff = inv_sqrt[row] * alpha * inv_sqrt[col]

        messages = self.my_cast(coeff, x[col])
        out = scatter_add(messages, row, dim=0, dim_size=num_nodes)

        if self.bias is not None:
            out = out + self.bias

        return out, alpha, alpha_index
Esempio n. 10
0
    def __matmul__(self, node_signal: torch.Tensor) -> torch.Tensor:
        """Apply the graph as a weighted adjacency operator to ``node_signal``.

        Each edge carries its sender's signal, scaled by the edge value, to
        its receiver, where contributions are summed::

            result[r] = sum over edges (s -> r) of edge_value * node_signal[s]

        ``node_signal`` must have ``self.n_node`` rows and ``self.edges`` must
        squeeze to a 1-D tensor (one scalar per edge).

        Returns a tensor with ``self.n_node`` rows along dim 0.
        """
        assert node_signal.shape[0] == self.n_node
        assert self.edges is not None and self.edges.squeeze().dim() == 1
        senders_features = node_signal[self.senders]
        # Reshape edge values to (E, 1, ..., 1) so they broadcast over every
        # trailing dimension of the gathered sender features.
        broadcast_edges = self.edges.view(-1,
                                          *([1] * (node_signal.dim() - 1)))
        # The original referenced a misspelled name here and raised NameError.
        weighted_senders = senders_features * broadcast_edges
        return scatter_add(src=weighted_senders,
                           index=self.receivers,
                           dim=0,
                           dim_size=self.n_node)
Esempio n. 11
0
    def forward(self, x, hyperedge_index, hyperedge_weight=None):
        r"""Apply the hypergraph convolution in two propagation passes.

        Features are first projected with ``self.weight``, then propagated
        twice through ``self.propagate`` with ``self.flow`` switched between
        the calls, using the inverse degrees ``B`` and ``D`` as edge norms.

        Args:
            x (Tensor): Node feature matrix :math:`\mathbf{X}`
            hyperedge_index (LongTensor): Hyperedge indices from
                :math:`\mathbf{H}`.
            hyperedge_weight (Tensor, optional): Sparse hyperedge weights from
                :math:`\mathbf{W}`. (default: :obj:`None`)
        """
        x = torch.matmul(x, self.weight)
        alpha = None

        if self.use_attention:
            # One attention coefficient per (index pair, head), normalised by a
            # softmax over entries sharing the same hyperedge_index[0] value.
            x = x.view(-1, self.heads, self.out_channels)
            x_i, x_j = x[hyperedge_index[0]], x[hyperedge_index[1]]
            alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
            alpha = F.leaky_relu(alpha, self.negative_slope)
            alpha = softmax(alpha, hyperedge_index[0], num_nodes=x.size(0))
            alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # D: inverse of the (optionally weighted) degree over hyperedge_index[0].
        if hyperedge_weight is None:
            D = degree(hyperedge_index[0], x.size(0), x.dtype)
        else:
            D = scatter_add(hyperedge_weight[hyperedge_index[1]],
                            hyperedge_index[0], dim=0, dim_size=x.size(0))
        D = 1.0 / D
        D[D == float("inf")] = 0  # zero degree -> contributes nothing

        # The number of hyperedges is inferred from the largest index present.
        if hyperedge_index.numel() == 0:
            num_edges = 0
        else:
            num_edges = hyperedge_index[1].max().item() + 1
        # B: inverse of the degree over hyperedge_index[1], optionally
        # rescaled by the hyperedge weights.
        B = 1.0 / degree(hyperedge_index[1], num_edges, x.dtype)
        B[B == float("inf")] = 0
        if hyperedge_weight is not None:
            B = B * hyperedge_weight

        # Two passes with opposite flow directions; note that self.flow is
        # left as 'target_to_source' after this method returns.
        self.flow = 'source_to_target'
        out = self.propagate(hyperedge_index, x=x, norm=B, alpha=alpha)
        self.flow = 'target_to_source'
        out = self.propagate(hyperedge_index, x=out, norm=D, alpha=alpha)

        if self.concat is True:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        return out
Esempio n. 12
0
def group_data(data,
               cluster=None,
               unique_pos_indices=None,
               mode="last",
               skip_keys=None):
    """ Group data based on indices in cluster.
    The option ``mode`` controls how data gets aggregated within each cluster.

    Only tensor attributes whose first dimension equals ``data.num_nodes`` are
    grouped; ``data`` is modified in place and returned.

    Parameters
    ----------
    data : Data
        Object that yields ``(key, item)`` pairs when iterated, exposes
        ``num_nodes`` and supports item assignment.
    cluster : torch.Tensor
        Tensor of the same size as the number of points in data. Each element is the cluster index of that point.
    unique_pos_indices : torch.tensor
        Tensor containing one index per cluster, this index will be used to select features and labels
    mode : str
        Option to select how the features and labels for each voxel is computed. Can be ``last`` or ``mean``.
        ``last`` selects the last point falling in a voxel as the representent, ``mean`` takes the average.
        In ``mean`` mode the ``y`` attribute is resolved by majority vote, and
        ``batch`` / ``SaveOriginalPosId.KEY`` are still selected with
        ``unique_pos_indices``.
    skip_keys: list
        Keys of attributes to skip in the grouping. ``None`` (the default)
        skips nothing.

    Raises
    ------
    ValueError
        If the argument required by ``mode`` is missing, or if ``data``
        contains a key with ``"edge"`` in its name.
    """

    assert mode in ["mean", "last"]
    if mode == "mean" and cluster is None:
        raise ValueError(
            "In mean mode the cluster argument needs to be specified")
    if mode == "last" and unique_pos_indices is None:
        raise ValueError(
            "In last mode the unique_pos_indices argument needs to be specified"
        )

    # A fresh empty container per call instead of a shared mutable default.
    if skip_keys is None:
        skip_keys = ()

    num_nodes = data.num_nodes
    for key, item in data:
        if "edge" in key:
            raise ValueError("Edges not supported. Wrong data type.")
        if key in skip_keys:
            continue
        if not (torch.is_tensor(item) and item.size(0) == num_nodes):
            continue

        if mode == "last" or key == "batch" or key == SaveOriginalPosId.KEY:
            # NOTE(review): in mean mode unique_pos_indices is not validated
            # above, so it may be None here -- confirm callers always pass it.
            data[key] = item[unique_pos_indices]
        elif mode == "mean":
            if key == "y":
                # Majority vote: shift labels to start at zero, count the
                # one-hot votes per cluster, then shift the winner back.
                item_min = item.min()
                votes = scatter_add(F.one_hot(item - item_min), cluster, dim=0)
                data[key] = votes.argmax(dim=-1) + item_min
            else:
                data[key] = scatter_mean(item, cluster, dim=0)
    return data
Esempio n. 13
0
    def message(self, x_i, x_j, edge_index, num_nodes):
        """Build per-edge messages from attention coefficients.

        Attention logits come from ``[x_i, x_j] * self.att`` followed by a
        leaky ReLU and a softmax over edges sharing ``edge_index[0]``.
        ``self.mod`` selects how the coefficients are applied:

        * ``"additive"``: attended message plus ``self.w`` times the raw one.
        * ``"scaled"``: attended message times a learned function of the
          per-edge degree (``l1``/``b1`` -> activation -> ``l2``/``b2``).
        * ``"f-additive"``: positive coefficients are increased by one.
        * ``"f-scaled"``: coefficients are multiplied by the degree.
        * anything else: plain attended message.
        """
        anchor = edge_index[0]
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, anchor, None, num_nodes)

        def per_edge_degree():
            # Number of edges sharing each edge's anchor node, shape (E, 1).
            ones = alpha.new_ones(anchor.size())
            counts = scatter_add(ones, anchor, dim_size=num_nodes)
            return counts[anchor].unsqueeze(-1)

        mode = self.mod
        if mode == "additive":
            unweighted = x_j * torch.ones_like(alpha).view(-1, self.heads, 1)
            extra = torch.mul(self.w, unweighted)
            return x_j * alpha.view(-1, self.heads, 1) + extra

        if mode == "scaled":
            scale = torch.matmul(per_edge_degree(), self.l1) + self.b1
            scale = self.activation(scale)
            scale = torch.matmul(scale, self.l2) + self.b2
            scale = scale.unsqueeze(-2)
            return torch.mul(x_j * alpha.view(-1, self.heads, 1), scale)

        if mode == "f-additive":
            alpha = torch.where(alpha > 0, alpha + 1, alpha)
        elif mode == "f-scaled":
            alpha = alpha * per_edge_degree()

        return x_j * alpha.view(-1, self.heads, 1)
Esempio n. 14
0
    def norm(self, edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        """Return the inverse weighted degree gathered at both ends of each edge.

        The degree of a node is the sum of ``edge_weight`` over edges whose
        second index equals that node; missing weights default to ones. The
        inverse is ``deg ** -1`` with infinities (zero degree) replaced by 0,
        computed without gradient tracking. ``improved`` is accepted but not
        used here.

        Returns a pair ``(inv_deg[row], inv_deg[col])``.
        """
        with torch.no_grad():
            if edge_weight is None:
                edge_weight = torch.ones((edge_index.size(1), ),
                                         dtype=dtype,
                                         device=edge_index.device)
            flat_weight = edge_weight.view(-1)
            assert flat_weight.size(0) == edge_index.size(1)
            source, target = edge_index
            inv_deg = scatter_add(flat_weight, target, dim=0,
                                  dim_size=num_nodes).pow(-1)
            inv_deg[inv_deg == float('inf')] = 0

        return inv_deg[source], inv_deg[target]
Esempio n. 15
0
 def forward(
         self,
         x: torch.Tensor,
         cond: torch.Tensor,
         seg_ids: torch.Tensor):
     """Pool conditioned records per segment and predict ``(mu, var)``.

     ``x`` is expanded to one row per record with ``seg_ids``, concatenated
     with ``cond``, passed through ``fc1``/``bn1``/ReLU, summed per segment,
     and mapped by ``bn2``/ReLU/``fc2``. The output is split into two
     halves along the last dim; the second goes through a softplus scaled by
     ``1 / log(2)``.

     Per the original annotations: ``x`` is "num_records * 3", ``cond`` is
     "ragged" and ``seg_ids`` has size ``num_records`` -- TODO confirm.
     """
     per_record = x[seg_ids, ...]
     features = torch.cat((per_record, cond), dim=-1)
     hidden = F.relu(self.bn1(self.fc1(features)))
     pooled = scatter_add(hidden, seg_ids, dim=0)
     stats = self.fc2(F.relu(self.bn2(pooled)))
     half = stats.size(-1) // 2
     mu, raw_var = torch.split(stats, half, -1)
     return mu, F.softplus(raw_var) / math.log(2)
Esempio n. 16
0
    def forward(self, x, edge_index):
        """Symmetric-normalised neighbour aggregation followed by a linear map.

        Self loops are added, edges are optionally dropped (``'dropedge'``),
        and each message is scaled by ``deg^-1/2`` of both endpoints before
        being summed per node and passed through ``self.linear``. Activation,
        the configured normalisation layer, the residual connection and
        dropout (``p=0.5``) are then applied when enabled.
        """
        residual_input = x
        edge_index, _ = add_remaining_self_loops(edge_index)

        if self.norm == 'dropedge':
            edge_index, _ = dropout_adj(edge_index,
                                        force_undirected=True,
                                        training=bool(self.training))

        row, col = edge_index
        inv_sqrt = degree(row).pow(-0.5)
        coeff = inv_sqrt[row] * inv_sqrt[col]

        # 'neighbornorm' normalises the neighbour features before aggregation.
        if self.norm == 'neighbornorm':
            neighbours = self.normlayer(x, edge_index)
        else:
            neighbours = x[col]

        weighted = coeff.view(-1, 1) * neighbours
        out = scatter_add(src=weighted, index=row, dim=0, dim_size=x.size(0))
        out = self.linear(out)

        if self.activation:
            out = F.relu(out)

        # These normalisations all act on the aggregated node features.
        if self.norm in ('batchnorm', 'layernorm', 'pairnorm', 'nodenorm'):
            out = self.normlayer(out)

        if self.residual:
            out = residual_input + out

        if self.dropout:
            out = F.dropout(out, p=0.5, training=self.training)

        return out
Esempio n. 17
0
def calc_log_prob(init_prob, traj_prob, rep_init, rep_rows, final_samples):
    """Sum the initial, trajectory and stop-related terms into one value.

    ``sampler.pred_stop`` is evaluated with an all-zeros flag on ``rep_init``
    (results summed per ``rep_rows`` into one row per final sample) and with
    an all-ones flag on ``final_samples``. When ``rep_rows`` is empty the
    first term is 0. Only the first element of each ``pred_stop`` result is
    used.
    """
    device = init_prob.device
    num_final = final_samples.shape[0]

    nonstop = 0
    if rep_rows.shape[0]:
        continue_flags = torch.zeros(rep_init.shape[0], 1).to(device)
        per_step = sampler.pred_stop(rep_init, continue_flags)[0]
        nonstop = scatter_add(per_step, rep_rows, dim=0, dim_size=num_final)

    stop_flags = torch.ones(num_final, 1).to(device)
    last_stop = sampler.pred_stop(final_samples, stop_flags)[0]

    return init_prob + traj_prob + nonstop + last_stop
Esempio n. 18
0
    def forward(self, x, edge_index, edge_weight=None):
        """Apply a K-term Chebyshev spectral filter to a batch of node signals.

        ``x`` is indexed as ``(batch, num_nodes, features)``; ``K`` is
        ``self.weight.size(0)``. The operator is built from the symmetrically
        normalised edge weights (negated), given self loops with value
        ``-0.05``, and doubled. The filter output is
        ``sum_k T_k @ self.weight[k]`` with ``T_0 = x``, ``T_1 = L x`` and
        ``T_k = 2 L T_{k-1} - T_{k-2}``.

        Args:
            x: Input signals.
            edge_index: Pair of index tensors ``(row, col)``.
            edge_weight: One weight per edge; defaults to all ones when
                ``None`` (previously ``None`` raised an AttributeError).

        NOTE(review): for ``k >= 2`` the previous code propagated ``x`` rather
        than ``T_{k-1}``, which contradicted its own reference comment; this
        changes outputs when ``K > 2``, so retrain or re-validate checkpoints.
        """
        row, col = edge_index
        batch, num_nodes, K = x.size(0), x.size(1), self.weight.size(0)

        if edge_weight is None:
            edge_weight = x.new_ones(edge_index.size(1))
        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        # Degree vector, then D^-1/2 with isolated nodes mapped to zero.
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg = deg.pow(-0.5)
        deg[torch.isinf(deg)] = 0
        lap = -deg[row] * edge_weight * deg[col]

        # Rescaling as in the original code: self loops of -0.05, then x2.
        fill_value = -0.05
        edge_index, lap = add_self_loops(edge_index, lap, fill_value,
                                         num_nodes)
        lap *= 2

        def apply_laplacian(signal):
            # Fold batch and feature dims into columns so that one sparse
            # product serves the whole batch, then restore the layout.
            flat = signal.permute(1, 2, 0).contiguous().view((num_nodes, -1))
            mixed = sparse_dense_mat_mul(edge_index, lap, num_nodes, flat)
            return mixed.view((num_nodes, -1, batch)).permute(2, 0, 1)

        # Perform the filter operation recurrently.
        Tx_0 = x
        out = torch.matmul(Tx_0, self.weight[0])
        if K > 1:
            Tx_1 = apply_laplacian(x)
            out = out + torch.matmul(Tx_1, self.weight[1])

        for k in range(2, K):
            Tx_2 = 2 * apply_laplacian(Tx_1) - Tx_0
            out = out + torch.matmul(Tx_2, self.weight[k])
            Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out = out + self.bias

        return out
Esempio n. 19
0
    def forward(self, data):
        """Accumulate per-node fingerprints over all loops and sum them per graph.

        Each module in ``self.loops`` receives the current node features and
        ``data.edge_index`` and returns ``(updated_features, fingerprint)``.
        The fingerprint contributions are added up per node and finally
        summed over the nodes sharing the same ``data.batch`` id.

        The accumulator is allocated on the device of ``data.batch`` so it can
        be combined with tensors that do not live on the CPU.
        """
        fingerprint = torch.zeros((data.batch.shape[0], self.fp_size),
                                  dtype=torch.float,
                                  device=data.batch.device)

        out = data.x
        for loop in self.loops:
            out, contribution = loop(out, data.edge_index)
            fingerprint += contribution

        return scatter_add(fingerprint, data.batch, dim=0)
Esempio n. 20
0
    def forward(self, x, edge_index, edge_weight=None, size=None):
        """Combine weighted neighbour sums with degree-scaled self terms.

        Self loops are removed, then the output is::

            deg * lin1(x) + sum_j w_ij * (x @ weight)_j + lin2(x)

        where ``deg`` is the sum of edge weights per ``edge_index[0]`` node.
        Missing edge weights default to ones. ``size`` is accepted but not
        used.
        """
        num_nodes = x.shape[0]
        h = torch.matmul(x, self.weight)

        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=x.dtype,
                                     device=edge_index.device)
        edge_index, edge_weight = remove_self_loops(edge_index=edge_index,
                                                    edge_attr=edge_weight)
        row, col = edge_index[0], edge_index[1]

        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        h_j = edge_weight.view(-1, 1) * h[col]
        aggr_out = scatter_add(h_j, row, dim=0, dim_size=num_nodes)

        # The former trailing add_self_loops call only rebound locals that
        # were never read again, so it has been dropped.
        return (deg.view(-1, 1) * self.lin1(x) + aggr_out) + self.lin2(x)
Esempio n. 21
0
    def forward(self, data):
        """Classify graphs: three ELU convolutions, sum pooling, three dense layers.

        ``data.x`` is overwritten in place after every convolution. Node
        features are summed per ``data.batch`` id; when the module-level
        ``args.no_train`` flag is set the pooled features are detached before
        the dense layers. Returns log-softmax scores over dim 1.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            data.x = F.elu(conv(data.x, data.edge_index))

        pooled = scatter_add(data.x, data.batch, dim=0)
        if args.no_train:
            pooled = pooled.detach()

        hidden = F.elu(self.fc1(pooled))
        hidden = F.elu(self.fc2(hidden))
        return F.log_softmax(self.fc3(hidden), dim=1)
 def __init__(self, graph: _BaseGraph):
     """Store ``graph`` and register the pooling reducers keyed by name.

     Each reducer takes ``(src, idx)`` and scatters ``src`` along dim 0 into
     ``graph.num_nodes`` rows; ``graph.num_nodes`` is read when the reducer
     is called, not when it is created.
     """
     self._graph = graph

     def pool_mean(src, idx):
         return torch_scatter.scatter_mean(
             src, idx, dim=0, dim_size=graph.num_nodes)

     def pool_sum(src, idx):
         return torch_scatter.scatter_add(
             src, idx, dim=0, dim_size=graph.num_nodes)

     def pool_max(src, idx):
         # scatter_max returns a tuple; only its first element is kept.
         return torch_scatter.scatter_max(
             src, idx, dim=0, dim_size=graph.num_nodes)[0]

     # TODO move these to the class definition or somewhere else
     self._pooling_functions = {
         'mean': pool_mean,
         'sum': pool_sum,
         'max': pool_max,
     }
Esempio n. 23
0
 def forward(self, graphs: tg.GraphBatch):
     """Summarise each graph of the batch into a global feature vector.

     Node features go through ``self.g_n`` and a ReLU, are summed per graph
     (graph ids derived from ``graphs.num_nodes_by_graph``) and mapped by
     ``self.h_n``. The result is returned via ``graphs.evolve`` with the
     global features set and node/edge features, senders and receivers
     cleared (``num_edges=0``).
     """
     node_hidden = F.relu(self.g_n(graphs.node_features))
     graph_ids = segment_lengths_to_ids(graphs.num_nodes_by_graph)
     pooled = torch_scatter.scatter_add(node_hidden,
                                        graph_ids,
                                        dim=0,
                                        dim_size=graphs.num_graphs)
     graph_features = self.h_n(pooled)
     return graphs.evolve(num_edges=0,
                          edge_features=None,
                          node_features=None,
                          global_features=graph_features,
                          senders=None,
                          receivers=None)
Esempio n. 24
0
 def forward(self, x, edge_index, edge_attr, u, history_vector,
             batch):
     """Attend over the nodes of each graph and pool them into one vector.

     ``self.node_mlp_1`` scores the nodes, the scores are normalised with a
     softmax over nodes sharing the same ``batch`` id, and the weighted node
     features are summed per ``batch`` id. ``edge_index``, ``edge_attr``,
     ``u`` and ``history_vector`` are accepted but not read here.

     Returns ``(attention_weights, pooled_vector)``.
     """
     scores = self.node_mlp_1(x)
     assert scores.dim() == x.dim() and scores.size(0) == x.size(0)
     attention = torch_geometric.utils.softmax(scores,
                                               batch,
                                               num_nodes=None)
     pooled = scatter_add(attention * x, batch, dim=0, dim_size=None)
     return attention, pooled
Esempio n. 25
0
def forward(f, shapes, lmax, device):
    """Evaluate ``f`` on a batch built from ``shapes`` and pool per shape.

    Every shape becomes a ``DataNeighbors`` graph with a ``(4, 1)`` tensor of
    ones as features and a cutoff radius of 1.1 (no self interaction).
    Spherical harmonics of degrees ``0..lmax`` of the edge attributes are
    passed to ``f`` as ``sh``; the outputs are summed per graph and squashed
    with ``tanh``.
    """
    r_max = 1.1
    features = torch.ones(4, 1)
    graphs = [
        DataNeighbors(features, shape, r_max, self_interaction=False)
        for shape in shapes
    ]
    batch = Batch.from_data_list(graphs).to(device)

    degrees = list(range(lmax + 1))
    sh = rsh.spherical_harmonics_xyz(degrees, batch.edge_attr, 'component')

    per_node = f(batch.x, batch.edge_index, batch.edge_attr, sh=sh, n_norm=3)
    return torch.tanh(scatter_add(per_node, batch.batch, dim=0))
Esempio n. 26
0
    def _sofmax(indexes, egde_values):
        """Normalise edge values with a softmax over edges sharing an index.

        :param indexes: nodes of each edge
        :param egde_values: values of each edge; the result of the scatter is
            indexed with three axes, so a 3-D tensor is expected
        :return: normalized values of edges considering nodes

        NOTE(review): the exponential is taken without subtracting a per-group
        maximum, so large inputs may overflow.
        """
        exp_values = torch.exp(egde_values)
        group_totals = scatter_add(exp_values, indexes, dim=0)
        return exp_values / group_totals[indexes, :, :]
Esempio n. 27
0
    def forward(self, x, edge_index, edge_attr,
                autoregressive_x=None, node_mask=None):
        '''
        Message-passing update followed by a pointwise feedforward update,
        each wrapped in dropout, a residual sum and a normalization layer.

        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :param autoregressive_x: tuple (s, V) of `torch.Tensor`. 
                If not `None`, will be used as src node embeddings
                for forming messages where src >= dst. The current node 
                embeddings `x` will still be the base of the update and the 
                pointwise feedforward.
        :param node_mask: array of type `bool` to index into the first
                dim of node embeddings (s, V). If not `None`, only
                these nodes will be updated.
        :return: the updated tuple (s, V); when `node_mask` is given, the
                input tensors of `x` are written to in place for the masked
                rows and returned.
        '''
        
        if autoregressive_x is not None:
            # Split edges by direction: src < dst uses `x` as the message
            # source, all remaining edges use `autoregressive_x`.
            src, dst = edge_index
            mask = src < dst
            edge_index_forward = edge_index[:, mask]
            edge_index_backward = edge_index[:, ~mask]
            edge_attr_forward = tuple_index(edge_attr, mask)
            edge_attr_backward = tuple_index(edge_attr, ~mask)
            
            dh = tuple_sum(
                self.conv(x, edge_index_forward, edge_attr_forward),
                self.conv(autoregressive_x, edge_index_backward, edge_attr_backward)
            )
            
            # Number of incoming edges per destination node (at least 1 to
            # avoid dividing by zero), shaped to broadcast over features.
            count = scatter_add(torch.ones_like(dst), dst,
                        dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)
            
            # The second tuple element has one more trailing dim than the
            # first, hence the extra unsqueeze.
            dh = dh[0] / count, dh[1] / count.unsqueeze(-1)

        else:
            dh = self.conv(x, edge_index, edge_attr)
        
        if node_mask is not None:
            # Keep the full embeddings so the masked rows can be written back.
            x_ = x
            x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)
            
        # Residual + norm around the message-passing update.
        x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))
        
        # Residual + norm around the pointwise feedforward update.
        dh = self.ff_func(x)
        x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))
        
        if node_mask is not None:
            # In-place write-back of the updated rows into the original tensors.
            x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
            x = x_
        return x
Esempio n. 28
0
 def propagate_homo(self, graph, input, typ):
     """Send linearly transformed source features to their target nodes.

     ``graph`` unpacks into ``(source, target)`` index tensors. Features are
     gathered along the second-to-last dim of ``input``, transformed with the
     weight (and bias, when the layer has one) selected by ``typ``, and
     summed per target node.

     Returns a tensor shaped ``[..., N, out_features]`` where ``N`` is
     ``input.size(-2)``.
     """
     source, target = graph
     signal = input[..., source, :]  # shape [..., E, in_features]
     bias = None if self.bias is None else self.bias[typ]
     message = F.linear(signal, self.weight[typ],
                        bias)  # shape: [..., E, out_features]
     return torch_scatter.scatter_add(message,
                                      target,
                                      dim=-2,
                                      dim_size=input.size(-2))
Esempio n. 29
0
def segment_softmax_with_bias(
        x: torch.Tensor,
        bias: torch.Tensor,
        seg_ids: torch.Tensor,
        eps: float = 1e-6) -> t.Tuple[torch.Tensor, torch.Tensor]:
    """Softmax over each segment's entries plus one bias logit per segment.

    For every segment the normaliser is the sum of the exponentials of all
    entries of ``x`` belonging to it plus the exponential of its bias (plus
    ``eps``). Logits are shifted by the per-segment maximum beforehand.

    Args:
        x (torch.Tensor): Input tensor, with shape [N, F]
        bias (torch.Tensor): Input bias, with shape [num_seg, ]
        seg_ids (torch.Tensor): Vector of size N
        eps (float): A small value for numerical stability

    Returns:
        tuple[torch.Tensor]: softmax of ``x`` with shape [N, F] and softmax
        of the bias with shape [num_seg, ]
    """
    num_seg = bias.size(0)

    # Per-row maximum over the row's own logits and its segment's bias.
    bias_per_row = bias.index_select(0, seg_ids).unsqueeze(-1)
    row_max, _ = torch.max(torch.cat([x, bias_per_row], dim=-1), dim=-1)
    # Per-segment maximum, shape [num_seg, ].
    seg_max, _ = torch_scatter.scatter_max(row_max,
                                           index=seg_ids,
                                           dim=0,
                                           dim_size=num_seg)

    x_shifted = x - seg_max.index_select(0, seg_ids).unsqueeze(-1)
    bias_shifted = bias - seg_max
    x_exp = torch.exp(x_shifted)
    bias_exp = torch.exp(bias_shifted)

    # Denominator per segment, shape [num_seg, ].
    seg_sum = torch_scatter.scatter_add(x_exp.sum(-1),
                                        dim=0,
                                        index=seg_ids,
                                        dim_size=num_seg)
    denom = seg_sum + bias_exp + eps

    x_softmax = x_exp / denom.index_select(0, seg_ids).unsqueeze(-1)
    bias_softmax = bias_exp / denom
    return x_softmax, bias_softmax
Esempio n. 30
0
        def node_model(x, edge_index, edge_attr, u=None, v_indices=None):
            """Update node features from aggregated edge features and globals.

            x: [N, F_x], where N is the number of nodes.
            edge_index: [2, E] with max entry N - 1.
            edge_attr: [E, F_e]
            u, v_indices: global features and the per-node index into them;
                both are required unless ``self.independent`` is set.

            In independent mode only ``self.node_mlp(x)`` is computed.
            Otherwise edge features are reduced per ``edge_index[0]`` node
            with the reduction named by ``self.e2v_agg`` and concatenated
            with ``x`` and ``u[v_indices]``.

            Raises ValueError for an unsupported ``self.e2v_agg`` (previously
            this surfaced as an UnboundLocalError on ``out``).
            """
            if self.independent:
                return self.node_mlp(x)

            row, _ = edge_index
            if self.e2v_agg == "sum":
                out = scatter_add(edge_attr, row, dim=0, dim_size=x.size(0))
            elif self.e2v_agg == "mean":
                out = scatter_mean(edge_attr, row, dim=0, dim_size=x.size(0))
            else:
                raise ValueError(
                    "Unsupported e2v_agg: {!r}".format(self.e2v_agg))
            out = torch.cat([x, out, u[v_indices]], dim=1)
            return self.node_mlp(out)