def forward(self, data):
            """Encode node features and decode one prediction row per graph.

            Reads ``x``, ``edge_attr``, ``edge_index`` and ``batch`` from
            ``data``. A weighted mean of the last three feature columns
            (weights from column ``-5``) is formed per ``batch`` id, extended
            with ``self.scatter_norm`` statistics and encoded. After every
            convolution another ``scatter_norm`` summary of the node features
            is appended before the result goes through ``self.decoder``.
            """
            node_feats, _edge_attr, edge_index, batch = (
                data.x, data.edge_attr, data.edge_index, data.batch)
            x = node_feats.float()

            # Per-graph weighted mean of columns -3: (called CoC in this code).
            weights = x[:, -5].view(-1, 1)
            weighted_total = scatter_sum(x[:, -3:] * weights, batch, dim=0)
            weight_total = scatter_sum(weights, batch, dim=0)
            graph_feats = torch.cat(
                [weighted_total / weight_total,
                 self.scatter_norm(x, batch)],
                dim=1)

            x = self.x_encoder(x)
            graph_feats = self.CoC_encoder(graph_feats)

            # Each layer updates the nodes and contributes one pooled summary.
            for layer, pool in zip(self.convs, self.scatter_norms):
                x = layer(x, edge_index)
                graph_feats = torch.cat([graph_feats, pool(x, batch)], dim=1)

            return self.decoder(graph_feats)
Esempio n. 2
0
    def test_qtensor_scatter_idx(self):
        """Scatter-summing a stacked QTensor must equal component-wise sums."""
        num_rows, num_feats = 1024, 64
        idx = torch.randint(low=0, high=256, size=(num_rows, ),
                            dtype=torch.int64)
        x = QTensor(*torch.randn(4, num_rows, num_feats))

        stacked = x.stack(dim=1)
        assert stacked.size() == torch.Size([num_rows, 4, num_feats])

        # Aggregate all four components at once on the stacked tensor.
        summed = scatter_sum(src=stacked, index=idx, dim=0,
                             dim_size=stacked.size(0))
        assert summed.size() == stacked.size()
        q_aggr = QTensor(*summed.permute(1, 0, 2))

        # Aggregate each component separately and rebuild the quaternion.
        per_component = [
            scatter_sum(part, idx, dim=0, dim_size=x.size(0))
            for part in (x.r, x.i, x.j, x.k)
        ]
        q_aggr2 = QTensor(*per_component)

        assert q_aggr == q_aggr2
def contrastive_loss(encoder_output, graph_data, sim_metric):
    """Return the mean negative log-ratio of positive to positive+negative similarity.

    ``encoder_output`` unpacks into two embedding tensors ``(es, ps)``.
    Positive index pairs come from ``graph_data.node_pos_index`` and
    ``graph_data.edge_pos_index``; negatives are drawn with
    ``_contrastive_sample`` from the matching ``*_neg_index`` fields.
    Similarities are computed by ``sim_metric(..., flatten=True,
    method='exp')``. Negative similarities are summed per first-row index and
    looked up again for each positive pair.
    """
    es, ps = encoder_output
    e_size = graph_data.x[0].size(0)

    ee_pos = graph_data.node_pos_index
    ee_neg = _contrastive_sample(ee_pos.size(1), graph_data.node_neg_index)
    ep_pos = graph_data.edge_pos_index
    ep_neg = _contrastive_sample(ep_pos.size(1), graph_data.edge_neg_index)

    def _pair_similarity(left, right, pairs):
        # Similarity between left[pairs[0]] and right[pairs[1]], row by row.
        return sim_metric(left.index_select(0, pairs[0]),
                          right.index_select(0, pairs[1]),
                          flatten=True,
                          method='exp')

    link_sim = _pair_similarity(es, es, ee_pos)
    non_sim = _pair_similarity(es, es, ee_neg)
    pos_sim = _pair_similarity(es, ps, ep_pos)
    neg_sim = _pair_similarity(es, ps, ep_neg)

    # Sum of negative similarities per anchor, gathered for each positive pair.
    link_neg_total = scatter_sum(non_sim, ee_neg[0], dim=-1, dim_size=e_size)
    link_ratio = link_sim / (link_sim + link_neg_total.index_select(0, ee_pos[0]))
    ep_neg_total = scatter_sum(neg_sim, ep_neg[0], dim=-1, dim_size=e_size)
    ep_ratio = pos_sim / (pos_sim + ep_neg_total.index_select(0, ep_pos[0]))

    return torch.cat([-link_ratio.log(), -ep_ratio.log()], dim=-1).mean()
Esempio n. 4
0
    def scatter_distribution(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                            out: Optional[torch.Tensor] = None,
                            dim_size: Optional[int] = None,
                            unbiased: bool = True) -> torch.Tensor:
        """Return per-group sum, mean, variance, max and min of ``src``.

        Groups are defined by ``index`` along ``dim``. The five statistics
        are concatenated along dimension 1, in that order. With ``unbiased``
        the variance is divided by ``count - 1`` (clamped to at least 1)
        instead of ``count``. When ``out`` is given, its size along ``dim``
        overrides ``dim_size`` and it is forwarded to the variance, max and
        min scatters.
        """
        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim += src.dim()

        # A lower-rank index is counted along its own last dimension.
        count_dim = dim if index.dim() > dim else index.dim() - 1

        group_size = scatter_sum(
            torch.ones(index.size(), dtype=src.dtype, device=src.device),
            index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        total = scatter_sum(src, index, dim, dim_size=dim_size)
        summ = total.clone()
        group_size = broadcast(group_size, total, dim).clamp(1)
        mean = total.div(group_size)

        centered = src - mean.gather(dim, index)
        var = scatter_sum(centered * centered, index, dim, out, dim_size)

        divisor = group_size.sub(1).clamp_(1) if unbiased else group_size
        var = var.div(divisor)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([summ, mean, var, maximum, minimum], dim=1)
Esempio n. 5
0
def scatter_std(src: torch.Tensor,
                index: torch.Tensor,
                dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,
                unbiased: bool = True) -> torch.Tensor:
    """Return the per-group standard deviation of ``src`` along ``dim``.

    Groups are given by ``index``. The squared deviations from each group
    mean are scatter-summed (into ``out`` when provided), divided by the
    group count (``count - 1`` clamped to 1 when ``unbiased``) and
    square-rooted. When ``out`` is given, its size along ``dim`` overrides
    ``dim_size``.
    """
    if out is not None:
        dim_size = out.size(dim)

    if dim < 0:
        dim += src.dim()

    # A lower-rank index is counted along its own last dimension.
    count_dim = dim if index.dim() > dim else index.dim() - 1

    group_size = scatter_sum(
        torch.ones(index.size(), dtype=src.dtype, device=src.device),
        index, count_dim, dim_size=dim_size)

    index = broadcast(index, src, dim)
    total = scatter_sum(src, index, dim, dim_size=dim_size)
    group_size = broadcast(group_size, total, dim).clamp(1)
    mean = total.div(group_size)

    centered = src - mean.gather(dim, index)
    squared_sum = scatter_sum(centered * centered, index, dim, out, dim_size)

    divisor = group_size.sub(1).clamp_(1) if unbiased else group_size
    return squared_sum.div(divisor).sqrt()
        def forward(self, *data):
            """Run the model on one or more node-feature tensors.

            Every positional argument is wrapped in a ``Data(x=...)`` object
            and the list is collated with ``Batch.from_data_list``, so each
            argument becomes one graph of the batch.

            Returns:
                The output of ``self.decoder`` applied to the per-graph
                features.
            """
            # ``*data`` always arrives as a tuple; the check is kept so the
            # fallback below stays reachable if the signature ever changes.
            if isinstance(data, tuple):
                from torch_geometric.data import Data, Batch
                data = Batch.from_data_list([Data(x=x) for x in data])
            try:
                x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            except AttributeError:
                # Plain tensor input: treat all rows as a single graph. Uses
                # ``self.device`` (as the rest of this method does) instead
                # of a module-level ``model`` that may not exist.
                x = data
                batch = torch.zeros(x.shape[0],
                                    device=self.device,
                                    dtype=torch.int64)

            x = x.float()

            # Per-graph weighted mean of columns -3:, weighted by column -5.
            weights = x[:, -5].view(-1, 1)
            CoC = scatter_sum(x[:, -3:] * weights, batch,
                              dim=0) / scatter_sum(weights, batch, dim=0)
            CoC = torch.cat([CoC, self.scatter_norm(x, batch)], dim=1)

            x = self.x_encoder(x)
            CoC = self.CoC_encoder(CoC)

            # Per-node hidden state threaded through every conv layer.
            h = torch.zeros((x.shape[0], self.hcs), device=self.device)
            for conv in self.convs:
                x, CoC, h = conv(x, CoC, h, batch)

            CoC = torch.cat([CoC, self.scatter_norm2(x, batch)], dim=1)

            CoC = self.decoder(CoC)

            return CoC
Esempio n. 7
0
def inverse_eig(X,edge_index,edge_weight,batch):
    """Return ``1 / (1e-4 + Z)`` for a per-node vector ``Z`` built by 20 propagation steps.

    ``Z`` starts as a column of ones with one row per row of ``X``. Each step
    replaces it with the ``edge_weight``-weighted sum of ``Z[edge_index[1]]``
    scattered onto ``edge_index[0]``, divided by 10. Finally ``Z`` is divided
    by the square root of the per-``batch`` sum of ``Z**2``. Everything runs
    under ``torch.no_grad()``.

    Only the row count of ``X`` is used.
    """
    num_nodes = X.shape[0]
    with torch.no_grad():
        # Allocate next to the edge data rather than hard-coding CUDA, so the
        # function also runs on CPU tensors.
        Z = torch.ones((num_nodes, 1), device=edge_weight.device)
        for _ in range(20):
            # dim_size keeps one row per node even when the highest-index
            # nodes never appear in edge_index[0]; without it Z would shrink
            # and the per-batch normalisation below would fail.
            Z = torch_scatter.scatter_sum(edge_weight[:,None] * Z[edge_index[1]],
                                          edge_index[0],
                                          dim=0,
                                          dim_size=num_nodes)/10
        Z = Z/torch_scatter.scatter_sum(Z**2,batch,dim=0).sqrt()[batch]
        return 1/(1e-4 + Z)
        def forward(self, data):
            """Encode nodes, edges and per-graph features, then decode per graph.

            Steps:
              1. Append per-node statistics of the incoming ``edge_attr``
                 (grouped by ``edge_index[1]``) to ``x``.
              2. Build per-graph features ``CoC``: the mean of the last three
                 original columns weighted by column 0, plus
                 ``scatter_distribution`` statistics of ``x`` per graph.
              3. Build node-to-graph edge features: unit offset from the
                 graph centre, its length, and the node features.
              4. For each op in ``self.ops``, update ``x``/``edge_attr``/
                 ``CoC``, refresh the hidden state ``h`` with a GRU cell
                 before and after updating ``CoC``, and fold ``h`` into ``x``.
              5. Pass ``CoC`` through ``self.decoders`` and ``self.decoder``.
            """
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            pos = x[:, -3:]

            # dim_size pins the statistics to one row per node. Without it
            # the scatter stops at the largest target index, and the cat
            # fails whenever the last nodes have no incoming edge.
            x = torch.cat([
                x,
                scatter_distribution(edge_attr,
                                     edge_index[1],
                                     dim=0,
                                     dim_size=x.shape[0])
            ],
                          dim=1)

            CoC = scatter_sum(pos * x[:, 0].view(-1, 1), batch,
                              dim=0) / scatter_sum(
                                  x[:, 0].view(-1, 1), batch, dim=0)
            CoC = torch.cat([CoC, scatter_distribution(x, batch, dim=0)],
                            dim=1)

            # One edge per node, pointing from the node to its graph's row.
            CoC_edge_index = torch.cat([
                torch.arange(x.shape[0]).view(1, -1).type_as(batch),
                batch.view(1, -1)
            ],
                                       dim=0)

            cart = pos[CoC_edge_index[0], -3:] - CoC[CoC_edge_index[1], :3]
            del pos

            # Normalise offsets to unit length; zero-length rows are left
            # untouched to avoid dividing by zero.
            rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)
            rho_mask = rho.squeeze() != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]

            CoC_edge_attr = torch.cat(
                [cart.type_as(x),
                 rho.type_as(x), x[CoC_edge_index[0]]], dim=1)

            x = self.act(self.x_encoder(x))
            edge_attr = self.act(self.edge_attr_encoder(edge_attr))
            CoC = self.act(self.CoC_encoder(CoC))

            h = torch.zeros((x.shape[0], self.hcs)).type_as(x)

            for i, op in enumerate(self.ops):
                x, edge_attr, CoC = op(x, edge_index, edge_attr, CoC, batch)
                # Hidden state update using the current graph features.
                h = self.act(self.GRUCells[i](torch.cat([
                    CoC[CoC_edge_index[1]], x[CoC_edge_index[0]], CoC_edge_attr
                ],
                                                        dim=1), h))
                CoC = self.act(self.lins1[i](torch.cat(
                    [CoC, scatter_distribution(h, batch, dim=0)], dim=1)))
                CoC = self.act(self.lins2[i](CoC))
                # Second update with the same GRU cell, now seeing the
                # refreshed graph features.
                h = self.act(self.GRUCells[i](torch.cat([
                    CoC[CoC_edge_index[1]], x[CoC_edge_index[0]], CoC_edge_attr
                ],
                                                        dim=1), h))
                x = self.act(self.lins3[i](torch.cat([x, h], dim=1)))

            for lin in self.decoders:
                CoC = self.act(lin(CoC))

            CoC = self.decoder(CoC)
            return CoC
        def forward(self, data):
            """Build node and per-graph features, run the conv stack, decode per graph.

            Reads ``x``, ``edge_attr``, ``edge_index`` and ``batch`` from
            ``data``; the incoming ``edge_attr`` is replaced by features
            built from ``time_edge_indeces`` / ``edge_feature_constructor``.
            Returns the output of ``self.decoder``.
            """
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            ###############
            x = x.float()
            ###############
            # The last three feature columns are used as positions.
            pos = x[:, -3:]

            graph_ids, graph_node_counts = batch.unique(return_counts=True)

            # Edges are derived from feature column 1 within each graph.
            time_edge_index = time_edge_indeces(x[:, 1], batch)

            edge_attr = edge_feature_constructor(x, time_edge_index)

            # Define central nodes at Center of Charge:
            # (mean of pos weighted by feature column 0, one row per graph)
            CoC = scatter_sum(pos * x[:, 0].view(-1, 1), batch,
                              dim=0) / scatter_sum(
                                  x[:, 0].view(-1, 1), batch, dim=0)

            # Define edge_attr for those edges:
            cart = pos[:, -3:] - CoC[batch, :3]
            del pos
            rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)
            # Normalise offsets to unit length, skipping zero-length rows.
            rho_mask = rho.squeeze() != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            CoC_edge_attr = torch.cat([cart.type_as(x), rho.type_as(x)], dim=1)

            # Everything is concatenated column-wise with x, so each term
            # must have one row per node.
            # NOTE(review): this requires edge_attr and time_edge_index[0] to
            # have exactly one entry per node -- confirm against
            # time_edge_indeces / edge_feature_constructor.
            x = torch.cat([
                x,
                x_feature_constructor(x, graph_node_counts), edge_attr,
                x[time_edge_index[0]], CoC_edge_attr, CoC[batch]
            ],
                          dim=1)

            # Extend the per-graph features with statistics of the node features.
            CoC = torch.cat([CoC, scatter_distribution(x, batch, dim=0)],
                            dim=1)

            x = self.x_encoder(x)
            CoC = self.CoC_encoder(CoC)

            # Per-node hidden state threaded through every conv layer.
            h = torch.zeros((x.shape[0], self.hcs)).type_as(x)

            # N_metalayers is not defined in this method; it comes from an
            # enclosing scope.
            for i in range(N_metalayers):
                x, CoC, h = self.convs[i](x, CoC, h, batch)

            CoC = torch.cat([CoC, scatter_distribution(x, batch, dim=0)],
                            dim=1)

            CoC = self.decoder(CoC)

            #             out = []
            #             for mlp in self.decoders:
            #                 out.append(mlp(CoC))
            #             CoC = torch.cat(out,dim=1)

            return CoC
Esempio n. 10
0
 def forward(self, x, x_clique, tree_edge_index, atom2clique_index, u,
             tree_batch):
     """Update clique features from neighbouring cliques, member atoms and ``u``.

     Neighbour clique features are summed along ``tree_edge_index`` and
     passed through ``self.mlp1``. Atom features are summed per clique via
     ``atom2clique_index`` and passed through ``self.mlp2``. Both are
     concatenated with ``x_clique`` and ``u[tree_batch]`` and fed to
     ``self.subgraph_mlp``.
     """
     num_cliques = x_clique.size(0)

     src_clique, dst_clique = tree_edge_index
     neighbour_msg = self.mlp1(
         scatter_sum(x_clique[src_clique], dst_clique, dim=0,
                     dim_size=num_cliques))

     atom_idx, clique_idx = atom2clique_index
     atom_msg = self.mlp2(
         scatter_sum(x[atom_idx], clique_idx, dim=0, dim_size=num_cliques))

     combined = torch.cat([atom_msg, x_clique, neighbour_msg, u[tree_batch]],
                          dim=1)
     return self.subgraph_mlp(combined)
        def forward(self, *data):
            """Run the model on one or more node-feature tensors.

            Each positional argument becomes one or more graphs. An argument
            with more than two dimensions is split along its first dimension
            into one graph per slice; otherwise it is a single graph. All
            tensors are squeezed before being wrapped in ``Data(x=...)`` and
            collated with ``Batch.from_data_list``.

            The per-graph features produced after encoding and after every
            conv layer are concatenated ("version a"), extended with
            ``self.scatter_norm2`` and decoded.
            """
            # ``*data`` always arrives as a tuple; the check is kept so the
            # fallback below stays reachable if the signature ever changes.
            if isinstance(data, tuple):
                from torch_geometric.data import Data, Batch
                datalist = []
                for x in data:
                    if x.dim() > 2:
                        datalist.extend(
                            Data(x=tmp_x.squeeze()) for tmp_x in x)
                    else:
                        datalist.append(Data(x=x.squeeze()))
                data = Batch.from_data_list(datalist)
            try:
                x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            except AttributeError:
                # Plain tensor input: treat all rows as a single graph. Uses
                # ``self.device`` (as the rest of this method does) instead
                # of a module-level ``model`` that may not exist.
                x = data
                batch = torch.zeros(x.shape[0],
                                    device=self.device,
                                    dtype=torch.int64)

            x = x.float()

            # Per-graph weighted mean of columns -3:, weighted by column -5.
            weights = x[:, -5].unsqueeze(-1)
            CoC = scatter_sum(x[:, -3:] * weights, batch,
                              dim=0) / scatter_sum(weights, batch, dim=0)
            # A zero total weight yields 0/0 = NaN; replace those with 0.
            CoC[CoC.isnan()] = 0
            CoC = torch.cat([CoC, self.scatter_norm(x, batch)], dim=1)

            x = self.x_encoder(x)
            CoC = self.CoC_encoder(CoC)

            h = torch.zeros((x.shape[0], self.hcs), device=self.device)
            out = CoC.clone()  #version a
            for conv in self.convs:
                x, CoC, h = conv(x, CoC, h, batch)
                out = torch.cat([out, CoC.clone()], dim=1)  #version a

            CoC = torch.cat([out, self.scatter_norm2(x, batch, out)],
                            dim=1)  #version a

            CoC = self.decoder(CoC)
            return CoC
Esempio n. 12
0
        def return_CoC_and_edge_attr(self, x, batch):
            """Return per-graph weighted centres and per-node offsets to them.

            The centre of each graph is the mean of the last three columns of
            ``x`` weighted by column 0. The second return value holds, per
            node, the unit offset from its graph centre followed by the
            offset length. Zero-length offsets are left unnormalised.
            """
            positions = x[:, -3:]
            weights = x[:, 0].view(-1, 1)

            weighted_total = scatter_sum(positions * weights, batch, dim=0)
            CoC = weighted_total / scatter_sum(weights, batch, dim=0)

            offsets = positions - CoC[batch]
            lengths = torch.norm(offsets, p=2, dim=1).view(-1, 1)
            # Skip zero-length rows so no division by zero happens.
            nonzero = lengths.squeeze() != 0
            offsets[nonzero] = offsets[nonzero] / lengths[nonzero]

            node_to_centre = torch.cat(
                [offsets.type_as(x), lengths.type_as(x)], dim=1)
            return CoC, node_to_centre
Esempio n. 13
0
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:
        """Return per-group mean, variance, skew, kurtosis, max and min of ``src``.

        Groups are defined by ``index`` along ``dim``. The six statistics are
        concatenated along dimension 1, in that order.

        Args:
            src: Values to summarise.
            index: Group id of each element of ``src``.
            dim: Dimension along which to group (negative values count from
                the end).
            out: Optional destination forwarded to the variance, skew,
                kurtosis, max and min scatters; when given, its size along
                ``dim`` overrides ``dim_size``.
            dim_size: Number of groups along ``dim``.
            unbiased: Divide by ``count - 1`` (clamped to at least 1) instead
                of ``count``. Note that the same divisor is also applied to
                the skew and kurtosis sums.
        """

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        # A lower-rank index is counted along its own last dimension.
        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        # Number of elements in every group.
        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        # Clamp so empty groups divide by 1 instead of 0.
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        src_minus_mean = (src - mean.gather(dim, index))
        var = src_minus_mean * src_minus_mean
        var = scatter_sum(var, index, dim, out, dim_size)

        # From here on ``count`` is the (possibly unbiased) divisor.
        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)

        # Per-element standardised third and fourth powers; 1e-7 guards
        # against division by zero variance.
        skew = src_minus_mean * src_minus_mean * src_minus_mean / (
            var.gather(dim, index) + 1e-7)**(1.5)
        kurtosis = (src_minus_mean * src_minus_mean * src_minus_mean *
                    src_minus_mean) / (var * var + 1e-7).gather(dim, index)

        skew = scatter_sum(skew, index, dim, out, dim_size)
        kurtosis = scatter_sum(kurtosis, index, dim, out, dim_size)

        skew = skew.div(count)
        kurtosis = kurtosis.div(count)

        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([mean, var, skew, kurtosis, maximum, minimum], dim=1)
Esempio n. 14
0
def scatter_mul(src, edge_index, edge_attr=None, dim=0):
    """Sum ``src`` rows selected by ``edge_index[0]`` onto ``edge_index[1]``.

    When ``edge_attr`` is given it must have one entry per edge. It is cast
    with ``.long()`` and multiplied onto the selected rows before summing.
    """
    gathered = src.index_select(dim, edge_index[0])
    if edge_attr is None:
        return scatter_sum(gathered, edge_index[1], dim)

    assert edge_index.size(1) == edge_attr.size(0)
    weighted = gathered * edge_attr.long()
    return scatter_sum(weighted, edge_index[1], dim)
Esempio n. 15
0
def scatter_logsumexp(src: torch.Tensor,
                      index: torch.Tensor,
                      dim: int = -1,
                      out: Optional[torch.Tensor] = None,
                      dim_size: Optional[int] = None,
                      eps: float = 1e-12) -> torch.Tensor:
    """Return the per-group ``log(sum(exp(src)))`` along ``dim``.

    The computation is stabilised by subtracting each group's maximum before
    exponentiating and adding it back after the logarithm.

    Args:
        src: Floating point values to reduce.
        index: Group id of each element of ``src``.
        dim: Dimension along which to reduce.
        out: Optional destination. Its values are re-centred by the group
            maxima and exponentiated in place before the sums are
            accumulated into it; its size along ``dim`` overrides
            ``dim_size``.
        dim_size: Number of groups; defaults to ``index.max() + 1``.
        eps: Added to the sum before the logarithm.

    Raises:
        ValueError: If ``src`` is not a floating point tensor.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    if out is not None:
        dim_size = out.size(dim)
    elif dim_size is None:
        dim_size = int(index.max()) + 1

    # torch.Size is immutable in eager mode; copy to a list before editing.
    size = list(src.size())
    size[dim] = dim_size
    max_value_per_index = torch.full(size,
                                     float('-inf'),
                                     dtype=src.dtype,
                                     device=src.device)
    # Fills max_value_per_index in place; the return value is not needed.
    scatter_max(src, index, dim, max_value_per_index, dim_size)
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_scores = src - max_per_src_element

    if out is not None:
        # ``out`` has one slot per group, so it must be re-centred with the
        # per-group maxima, not with the src-shaped per-element tensor.
        out = out.sub_(max_value_per_index).exp_()

    sum_per_index = scatter_sum(recentered_scores.exp_(), index, dim, out,
                                dim_size)

    return sum_per_index.add_(eps).log_().add_(max_value_per_index)
Esempio n. 16
0
 def forward(self, data):
     """Update ``data.vn_h`` from pooled ``data.h`` and add it back to every node.

     ``data.h`` is summed per ``data.batch`` id. An existing ``data.vn_h`` is
     added to that sum. The result goes through ``self.mlp``, is stored as
     the new ``data.vn_h`` and added in place to ``data.h``. Returns the
     mutated ``data``.
     """
     pooled = scatter_sum(data.h, data.batch, dim=0)
     has_previous_state = 'vn_h' in data
     if has_previous_state:
         pooled += data.vn_h

     data.vn_h = self.mlp(pooled)
     per_node_update = data.vn_h[data.batch]
     data.h += per_node_update
     return data
Esempio n. 17
0
def weighted_dimwise_median(A: torch.sparse.FloatTensor, x: torch.Tensor,
                            **kwargs) -> torch.Tensor:
    """Aggregate node features with a weighted, per-dimension median.

    Non-CUDA inputs are delegated to ``weighted_dimwise_median_cpu`` with the
    same arguments. On CUDA, a custom kernel picks for every (row, column)
    the index of the median element, and the selected values are scaled by
    the row sums of ``A``.

    Parameters
    ----------
    A : torch.sparse.FloatTensor
        Sparse [n, n] weighted/normalized adjacency matrix.
    x : torch.Tensor
        Dense [n, d] node attributes/embeddings.

    Returns
    -------
    torch.Tensor
        The aggregated embeddings, shape [n, d].
    """
    if not A.is_cuda:
        return weighted_dimwise_median_cpu(A, x, **kwargs)

    assert A.is_sparse
    num_nodes, num_dims = x.shape

    median_rows = custom_cuda_kernels.dimmedian_idx(x, A)
    columns = torch.arange(num_dims, device=A.device)
    columns = columns.view(1, -1).expand(num_nodes, num_dims)
    picked = x[median_rows, columns]

    # Total edge weight per row, repeated across all feature columns.
    row_weight = torch_scatter.scatter_sum(A._values(), A._indices()[0],
                                           dim=-1)
    row_weight = row_weight.view(-1, 1).expand(num_nodes, num_dims)
    return row_weight * picked
Esempio n. 18
0
 def forward(self, x, edge_index, edge_attr, u, batch):
     """Update node features from incoming messages and ``u``.

     Source-node features are concatenated with ``edge_attr``, summed per
     target node and passed through ``self.node_mlp_1``. The result is
     concatenated with ``x`` and ``u[batch]`` and fed to ``self.node_mlp_2``.
     """
     source, target = edge_index
     messages = torch.cat([x[source], edge_attr], dim=1)
     incoming = scatter_sum(messages, target, dim=0, dim_size=x.size(0))
     incoming = self.node_mlp_1(incoming)
     node_inputs = torch.cat([x, incoming, u[batch]], dim=1)
     return self.node_mlp_2(node_inputs)
Esempio n. 19
0
    def forward(self, inputs: ElementsToSummaryRepresentationInput) -> torch.Tensor:
        """Reduce element embeddings to one vector per sample with multi-head attention.

        A query is computed per sample and broadcast to that sample's
        elements. The per-head dot product between the query and the
        projected element values gives attention scores, which are
        softmax-normalised within each sample. The attention-weighted element
        embeddings are summed per sample, the heads are flattened, and the
        result goes through the output layer.
        """
        query = self.__query_layer(inputs)  # [num_graphs, H]
        query_per_node = query[inputs.element_to_sample_map]  # [num_vertices, H]
        values = self.__value_layer(inputs.element_embeddings)  # [num_vertices, H]

        # Split the hidden dimension into heads. The query must be reshaped
        # from itself; reshaping ``values`` here would throw the query away
        # and score each element against its own values.
        query_per_node = query_per_node.reshape(
            (query_per_node.shape[0], self.__num_heads, query_per_node.shape[1] // self.__num_heads)
        )
        values = values.reshape(
            (values.shape[0], self.__num_heads, values.shape[1] // self.__num_heads)
        )

        attention_scores = torch.einsum(
            "vkh,vkh->vk", query_per_node, values
        )  # [num_vertices, num_heads]
        attention_probs = torch.exp(
            scatter_log_softmax(attention_scores, index=inputs.element_to_sample_map, dim=0, eps=0)
        )  # [num_vertices, num_heads]

        outputs = attention_probs.unsqueeze(-1) * inputs.element_embeddings.unsqueeze(
            1
        )  # [num_vertices, num_heads, D']
        per_graph_outputs = scatter_sum(
            outputs, index=inputs.element_to_sample_map, dim=0, dim_size=inputs.num_samples
        )  # [num_graphs, num_heads, D']
        per_graph_outputs = per_graph_outputs.reshape(
            (per_graph_outputs.shape[0], -1)
        )  # [num_graphs, num_heads * D']

        return self.__output_layer(per_graph_outputs)  # [num_graphs, D']
Esempio n. 20
0
 def forward(self, inputs: ElementsToSummaryRepresentationInput) -> torch.Tensor:
     """Sum element embeddings per sample, each scaled by a learned sigmoid gate."""
     embeddings = inputs.element_embeddings
     gate_logits = self.__weights_layer(embeddings).squeeze(-1)
     weights = torch.sigmoid(gate_logits)  # [num_vertices]
     gated = embeddings * weights.unsqueeze(-1)
     return scatter_sum(
         gated,
         index=inputs.element_to_sample_map,
         dim=0,
         dim_size=inputs.num_samples,
     )  # [num_graphs, D']
Esempio n. 21
0
def scatter_softmax(
    src: Tensor, index: Tensor, dim: int, dim_size: Optional[int] = None
) -> Tensor:
    """Return the softmax of ``src`` computed within each ``index`` group along ``dim``.

    The group maximum is subtracted before exponentiating for numerical
    stability. ``index`` is used to index the reduced tensors back to the
    shape of ``src``, so it is expected to be one-dimensional. An empty
    ``src`` is returned unchanged.
    """
    if src.numel() == 0:
        return src
    # Normalise a negative dim first: ``(slice(None),) * dim`` is empty for
    # dim < 0, which would silently index dimension 0 instead.
    if dim < 0:
        dim += src.dim()
    # Indexing tuple that expands per-group results back to per-element.
    slice_tuple = (slice(None),) * dim + (index,)
    src = src - scatter_max(src, index, dim, dim_size=dim_size)[0][slice_tuple]
    exp = torch.exp(src)
    return exp / scatter_sum(exp, index, dim, dim_size=dim_size)[slice_tuple]
Esempio n. 22
0
 def aggregate(self,
               inputs: torch.Tensor,
               index: torch.Tensor,
               dim_size: Optional[int] = None) -> torch.Tensor:
     """Aggregate ``inputs`` per ``index`` with softmax weights.

     The weights are the softmax of ``inputs * self.beta`` within each group
     along ``self.node_dim``; the weighted inputs are then summed per group.
     """
     weights = scatter_softmax(inputs * self.beta, index, dim=self.node_dim)
     weighted_inputs = inputs * weights
     return scatter_sum(weighted_inputs,
                        index,
                        dim=self.node_dim,
                        dim_size=dim_size)
Esempio n. 23
0
 def forward(self, x, edge_index, edge_attr, batch):
     """Compute new node features from the incoming edge attributes only.

     Shapes as noted by the original author:
       x: [N, h], where N is the number of nodes.
       edge_index: [2, E] with max entry N - 1.
       edge_attr: [E, F_e]
       batch: [N] with max entry B - 1 (unused here).

     Edge attributes go through ``self.node_mlp_1``, are summed per target
     node and passed through ``self.node_mlp_2``.
     """
     _, target = edge_index
     edge_messages = self.node_mlp_1(edge_attr)
     per_node = scatter_sum(edge_messages, target, dim=0, dim_size=x.size(0))
     return self.node_mlp_2(per_node)
Esempio n. 24
0
    def forward(self,X,edge_index,edge_weight,batch):
        """Apply the graph-conv stack with per-layer normalisation.

        Each layer computes ``m[0](X)`` plus the ``edge_weight``-weighted sum
        of ``m[1](X)`` gathered at ``edge_index[1]`` and scattered onto
        ``edge_index[0]``. With ``self.res`` the update is added to ``X``,
        otherwise it replaces ``X``. The result is normalised by
        ``self.norm[idx]`` and passed through LeakyReLU.

        Raises:
            ValueError: If any activation becomes NaN.
        """
        activation = torch.nn.LeakyReLU()
        X = self.start(X)
        for idx, m in enumerate(self.intermediate):
            self_term = m[0](X)
            messages = edge_weight[:, None] * m[1](X)[edge_index[1]]
            Update = self_term + torch_scatter.scatter_sum(
                messages, edge_index[0], dim=0)

            X = X + Update if self.res else Update

            X = activation(self.norm[idx](X, edge_index, edge_weight, batch))
            if torch.isnan(X).any():
                raise ValueError
        return self.finish(X)
Esempio n. 25
0
 def forward(self,X,edge_index,edge_weight,batch):
     """Project the input, run the graph-conv layers with ``self.bn``, project out.

     Each layer computes ``m[0](X)`` plus the ``edge_weight``-weighted sum of
     ``m[1](X)`` gathered at ``edge_index[1]`` and scattered onto
     ``edge_index[0]``, followed by ``self.bn[idx]`` and LeakyReLU.
     ``batch`` is not used.
     """
     activation = torch.nn.LeakyReLU()

     # Project to int_channels
     X = self.start(X)

     # Run through GraphConv layers
     for idx, m in enumerate(self.intermediate):
         self_term = m[0](X)
         messages = edge_weight[:, None] * m[1](X)[edge_index[1]]
         X = self_term + torch_scatter.scatter_sum(messages, edge_index[0],
                                                   dim=0)
         X = activation(self.bn[idx](X))

     # Project to out_channels
     return self.finish(X)
Esempio n. 26
0
        def forward(self,data):
            """Encode nodes and per-graph features, mix them with ``self.TConv``, decode.

            The per-graph rows are stacked on top of the node rows and the
            edges come from ``self.return_edge_index(batch)`` (the
            ``edge_index`` in ``data`` is not used). Rows indexed by
            ``batch.unique()`` are decoded.
            """
            feats, _edge_attr, _edge_index, batch = (
                data.x, data.edge_attr, data.edge_index, data.batch)
            x = feats.float()

            # Per-graph weighted mean of columns -3:, weighted by column -5.
            weighted_total = scatter_sum(x[:, -3:] * x[:, -5].view(-1, 1),
                                         batch, dim=0)
            weight_total = scatter_sum(x[:, -5].view(-1, 1), batch, dim=0)
            graph_feats = torch.cat(
                [weighted_total / weight_total,
                 self.scatter_norm(x, batch)],
                dim=1)

            x = self.x_encoder(x)
            graph_feats = self.CoC_encoder(graph_feats)

            # Graph rows come first, followed by the node rows.
            stacked = torch.cat([graph_feats, x], dim=0)
            stacked = self.TConv(stacked, self.return_edge_index(batch))

            return self.decoder(stacked[batch.unique()])
Esempio n. 27
0
    def forward(self, data):
        """Run the kNN graph network and return pooled per-graph features.

        The edges in ``data`` are discarded and rebuilt as a k-nearest-
        neighbour graph on the ``self.pos_idx`` columns of ``x``. After two
        graph convolutions with dense layers and residual blocks in between,
        the node features are pooled per graph with max, min, sum and mean,
        concatenated, and passed through ``self.nncat`` and ReLU.
        """
        k = self.k
        device = self.device
        pos_idx = self.pos_idx
        x, edge_index, batch = data.x, data.edge_index, data.batch
        edge_index = knn_graph(x=x[:, pos_idx], k=k, batch=batch).to(device)

        x = self.relu(self.GGconv1(x, edge_index))
        x = self.relu(self.nn1(x))

        # First pair of residual blocks.
        x = x + self.resblock1(x)
        x = x + self.resblock2(x)

        x = self.relu(self.nn2(x))
        x = self.relu(self.GGconv2(x, edge_index))

        # Second pair of residual blocks.
        x = x + self.resblock3(x)
        x = x + self.resblock4(x)

        x = self.relu(self.nn3(x))

        # Per-graph pooling: max, min, sum, mean (in that column order).
        pooled_max, _ = scatter_max(x, batch, dim=0)
        pooled_min, _ = scatter_min(x, batch, dim=0)
        pooled_sum = scatter_sum(x, batch, dim=0)
        pooled_mean = scatter_mean(x, batch, dim=0)
        x = torch.cat((pooled_max, pooled_min, pooled_sum, pooled_mean),
                      dim=1)

        return self.relu(self.nncat(x))
Esempio n. 28
0
    def test_scatter_batch_idx(self):
        """``scatter_sum`` and ``global_add_pool`` must agree, stacked or per component."""
        num_graphs, num_nodes, num_feats = 128, 2048, 64
        idx = torch.randint(low=0, high=num_graphs, size=(num_nodes, ),
                            dtype=torch.int64)
        x = QTensor(*torch.randn(4, num_nodes, num_feats))

        stacked = x.stack(dim=1)
        assert stacked.size() == torch.Size([num_nodes, 4, num_feats])

        # Stacked tensor: scatter_sum versus global_add_pool.
        x_aggr = scatter_sum(src=stacked, index=idx, dim=0)
        pooled = global_add_pool(stacked, batch=idx)
        assert torch.allclose(x_aggr, pooled)

        x_aggr = x_aggr.permute(1, 0, 2)
        q_aggr = QTensor(*x_aggr)

        # Component-wise scatter_sum.
        components = (x.r, x.i, x.j, x.k)
        summed = [scatter_sum(part, idx, dim=0) for part in components]
        q_aggr2 = QTensor(*summed)

        assert q_aggr == q_aggr2
        for position, expected in enumerate(summed):
            assert torch.allclose(x_aggr[position], expected)

        # Component-wise global_add_pool.
        q_aggr3 = QTensor(*[global_add_pool(part, idx) for part in components])

        assert q_aggr == q_aggr2 == q_aggr3
Esempio n. 29
0
def get_metrics(model,test_loader,k=64):
    """Average layer-wise diagnostics of ``model`` over ``test_loader``.

    For every batch the model's layers are replayed one by one and four
    quantities are accumulated per layer: the sigmoid of the norm layer's
    ``s1`` parameter, ``batched_MAD``, ``batched_agg`` and
    ``rayleigh_quotient``.

    Args:
        model: Model exposing ``start``, ``intermediate``, ``norm`` and
            ``res``. It is switched to eval mode.
        test_loader: Iterable of batches with ``x``, ``edge_index``,
            ``edge_weight``, ``batch``, ``eig_max`` and ``eig_min``.
        k: Length of the per-layer accumulators; must be at least the number
            of layers in ``model.intermediate``.

    Returns:
        A tuple ``(mad, agg, ray, s)`` of single-element lists, each holding
        a length-``k`` tensor of batch-averaged values.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    MAD, Agg, Ray, S = (torch.zeros(k) for _ in range(4))
    activation = torch.nn.LeakyReLU()
    num_batches = 0

    model.eval()

    # Iterate over dataset and accumulate metrics
    for data in test_loader:
        num_batches += 1
        X = data.x.cuda()
        edge_index, edge_weight = data.edge_index.cuda(), data.edge_weight.cuda()
        batch = data.batch.cuda()

        X = model.start(X)

        # Iterate over model layers
        for jdx, m in enumerate(model.intermediate):
            Update = m[0](X) + torch_scatter.scatter_sum(
                edge_weight[:, None] * m[1](X)[edge_index[1]],
                edge_index[0],
                dim=0)

            X = X + Update if model.res else Update
            # The norm layer is selected by the layer index; indexing with
            # the batch counter applied the wrong layer's normalisation.
            X = activation(model.norm[jdx](X, edge_index, edge_weight, batch))

            # Fetch S1
            S[jdx] += torch.sigmoid(model.norm[jdx].s1).item()

            # Compute MAD
            MAD[jdx] += batched_MAD(X, edge_index, edge_weight).mean().item()

            # Compute AggNorm
            Agg[jdx] += batched_agg(X, edge_index, edge_weight, batch).item()

            # Compute Normalized Rayleigh
            Ray[jdx] += rayleigh_quotient(X, edge_index, edge_weight, batch,
                                          data.eig_max.cuda(),
                                          data.eig_min.cuda()).mean().item()

    if num_batches == 0:
        raise ValueError('test_loader yielded no batches')

    # Return metrics (each wrapped in a list, as callers expect)
    return ([MAD / num_batches], [Agg / num_batches], [Ray / num_batches],
            [S / num_batches])
Esempio n. 30
0
def scatter_mean(
        src: torch.Tensor,
        index: torch.Tensor,
        dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the per-group mean of ``src`` together with the group sizes.

    The sums are accumulated with ``scatter_sum`` (into ``out`` when given)
    and divided in place by the group counts clamped to at least 1. The
    second return value is the unclamped count tensor, so empty groups
    report 0.
    """
    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    # Resolve the dimension to count along; a lower-rank index is counted
    # along its own last dimension.
    index_dim = dim + src.dim() if dim < 0 else dim
    if index.dim() <= index_dim:
        index_dim = index.dim() - 1

    count = scatter_sum(
        torch.ones(index.size(), dtype=src.dtype, device=src.device), index,
        index_dim, None, dim_size)
    count_ret = count.clone()

    count.clamp_(1)
    out.div_(broadcast(count, out, dim))
    return out, count_ret