Example 1
    def aggregate(self, inputs, index, ptr=None, dim_size=None):

        if self.aggr in ['add', 'mean', 'max', None]:
            return super(GENConv, self).aggregate(inputs, index, ptr, dim_size)

        elif self.aggr in ['softmax_sg', 'softmax', 'softmax_sum']:

            inputs = inputs * self.edge_mask1_train * self.edge_mask2_fixed
            if self.learn_t:
                out = scatter_softmax(inputs * self.t,
                                      index,
                                      dim=self.node_dim)
            else:
                with torch.no_grad():
                    out = scatter_softmax(inputs * self.t,
                                          index,
                                          dim=self.node_dim)

            out = scatter(inputs * out,
                          index,
                          dim=self.node_dim,
                          dim_size=dim_size,
                          reduce='sum')

            if self.aggr == 'softmax_sum':
                self.sigmoid_y = torch.sigmoid(self.y)
                degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
                out = torch.pow(degrees, self.sigmoid_y) * out

            return out
        else:
            raise NotImplementedError('To be implemented')
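A minimal, self-contained sketch (not taken from the repository above; toy values) of what the 'softmax' branch computes: scatter_softmax turns the incoming messages into attention weights that sum to 1 per destination node, and the weighted messages are then summed per node with scatter.

import torch
from torch_scatter import scatter, scatter_softmax

# five scalar messages: the first three go to node 0, the last two to node 1
inputs = torch.tensor([1.0, 2.0, 3.0, 0.5, 1.5])
index = torch.tensor([0, 0, 0, 1, 1])
t = 1.0  # stands in for the (possibly learnable) inverse temperature `self.t` above

weights = scatter_softmax(inputs * t, index, dim=0)  # sums to 1 within each group
out = scatter(inputs * weights, index, dim=0, dim_size=2, reduce='sum')
print(out)  # one softmax-weighted aggregate per destination node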
Example 2
    def aggregate(self, inputs: Tensor, index: Tensor,
                  dim_size: Optional[int] = None) -> Tensor:

        if self.aggr == 'softmax':
            out = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
            return scatter(inputs * out, index, dim=self.node_dim,
                           dim_size=dim_size, reduce='sum')

        elif self.aggr == 'softmax_sg':
            out = scatter_softmax(inputs * self.t, index,
                                  dim=self.node_dim).detach()
            return scatter(inputs * out, index, dim=self.node_dim,
                           dim_size=dim_size, reduce='sum')

        elif self.aggr == 'power':
            min_value, max_value = 1e-7, 1e1
            torch.clamp_(inputs, min_value, max_value)
            out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim,
                          dim_size=dim_size, reduce='mean')
            torch.clamp_(out, min_value, max_value)
            return torch.pow(out, 1 / self.p)

        else:
            return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                           reduce=self.aggr)
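A small standalone sketch of the 'power' branch as a generalized (power) mean; the clamps from the snippet above are omitted here since they only guard the numerical range, and the values are made up for illustration.

import torch
from torch_scatter import scatter

# toy messages: three for node 0, two for node 1
inputs = torch.tensor([1.0, 2.0, 3.0, 0.5, 1.5])
index = torch.tensor([0, 0, 0, 1, 1])

def power_mean(x, idx, p):
    out = scatter(x.pow(p), idx, dim=0, reduce='mean')
    return out.pow(1.0 / p)

print(power_mean(inputs, index, p=1.0))  # arithmetic mean per group: [2.0, 1.0]
print(power_mean(inputs, index, p=2.0))  # quadratic mean, slightly larger
# as p grows, the power mean approaches the per-group maximum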
Example 3
    def aggregate(self,
                  inputs: Tensor,
                  index: Tensor,
                  dim_size: Optional[int] = None) -> Tensor:

        if self.aggr == 'softmax':
            out = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
            return scatter(inputs * out,
                           index,
                           dim=self.node_dim,
                           dim_size=dim_size,
                           reduce='sum')

        elif self.aggr == 'softmax_sg':
            out = scatter_softmax(inputs * self.t, index,
                                  dim=self.node_dim).detach()
            return scatter(inputs * out,
                           index,
                           dim=self.node_dim,
                           dim_size=dim_size,
                           reduce='sum')
        elif self.aggr == 'stat':
            _mean = scatter_mean(inputs,
                                 index,
                                 dim=self.node_dim,
                                 dim_size=dim_size)
            _std = scatter_std(inputs,
                               index,
                               dim=self.node_dim,
                               dim_size=dim_size).detach()
            _min = scatter_min(inputs,
                               index,
                               dim=self.node_dim,
                               dim_size=dim_size)[0]
            _max = scatter_max(inputs,
                               index,
                               dim=self.node_dim,
                               dim_size=dim_size)[0]

            _mean = _mean.unsqueeze(dim=-1)
            _std = _std.unsqueeze(dim=-1)
            _min = _min.unsqueeze(dim=-1)
            _max = _max.unsqueeze(dim=-1)

            stat = torch.cat([_mean, _std, _min, _max], dim=-1)
            stat = self.lin_stat(stat)
            stat = stat.squeeze(dim=-1)
            return stat

        else:
            min_value, max_value = 1e-7, 1e1
            torch.clamp_(inputs, min_value, max_value)
            out = scatter(torch.pow(inputs, self.p),
                          index,
                          dim=self.node_dim,
                          dim_size=dim_size,
                          reduce='mean')
            torch.clamp_(out, min_value, max_value)
            return torch.pow(out, 1 / self.p)
Example 4
    def aggregate(self, inputs, index, ptr=None, dim_size=None):

        if self.aggr in ["add", "mean", "max", None]:
            return super(GenMessagePassing,
                         self).aggregate(inputs, index, ptr, dim_size)

        elif self.aggr in ["softmax_sg", "softmax", "softmax_sum"]:

            if self.learn_t:
                out = scatter_softmax(inputs * self.t,
                                      index,
                                      dim=self.node_dim)
            else:
                with torch.no_grad():
                    out = scatter_softmax(inputs * self.t,
                                          index,
                                          dim=self.node_dim)

            out = scatter(inputs * out,
                          index,
                          dim=self.node_dim,
                          dim_size=dim_size,
                          reduce="sum")

            if self.aggr == "softmax_sum":
                self.sigmoid_y = torch.sigmoid(self.y)
                degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
                out = torch.pow(degrees, self.sigmoid_y) * out

            return out

        elif self.aggr in ["power", "power_sum"]:
            min_value, max_value = 1e-7, 1e1
            torch.clamp_(inputs, min_value, max_value)
            out = scatter(
                torch.pow(inputs, self.p),
                index,
                dim=self.node_dim,
                dim_size=dim_size,
                reduce="mean",
            )
            torch.clamp_(out, min_value, max_value)
            out = torch.pow(out, 1 / self.p)
            # torch.clamp(out, min_value, max_value)

            if self.aggr == "power_sum":
                self.sigmoid_y = torch.sigmoid(self.y)
                degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
                out = torch.pow(degrees, self.sigmoid_y) * out

            return out

        else:
            raise NotImplementedError("To be implemented")
Example 5
        def forward(self, data):
            x, batch = data.x.float(), data.batch

            graph_ids, graph_node_counts = batch.unique(return_counts=True)

            CoC, CoC_edge_attr = self.return_CoC_and_edge_attr(x, batch)

            x = torch.cat([
                x,
                x_feature_constructor(x, graph_node_counts), CoC_edge_attr,
                CoC[batch]
            ],
                          dim=1)
            beta = graph_node_counts[batch].view(-1, 1) / 100

            CoC = torch.cat([
                CoC,
                scatter_distribution(scatter_softmax(
                    self.betas[0](beta) * self.atts[0](x), batch, dim=0) * x,
                                     batch,
                                     dim=0)
            ],
                            dim=1)

            x = self.x_encoder(x)
            CoC = self.act(self.CoC_encoder(CoC))

            h = torch.zeros((x.shape[0], self.hcs)).type_as(x)
            for i in range(N_metalayers):
                h = self.act(self.GRUCells[i](torch.cat([CoC[batch], x],
                                                        dim=1), h))
                CoC = self.act(self.CoC_batch_norm[i](
                    self.lins_CoC_msg[i](scatter_distribution(scatter_softmax(
                        self.betas[i + 1](beta) * self.atts[i + 1](h),
                        batch,
                        dim=0) * h,
                                                              batch,
                                                              dim=0)) +
                    self.lins_CoC_self[i](CoC)))
                h = self.act(self.GRUCells[i](torch.cat([CoC[batch], x],
                                                        dim=1), h))
                x = self.act(self.x_batch_norm[i](self.lins_x_msg[i](h) +
                                                  self.lins_x_self[i](x)))

            CoC = torch.cat([
                CoC,
                scatter_distribution(scatter_softmax(
                    self.betas[-1](beta) * self.atts[-1](x), batch, dim=0) * x,
                                     batch,
                                     dim=0)
            ],
                            dim=1)

            return self.decoder(CoC)
Example 6
    def forward(self, x, batch, bsize=None):
        r"""Args:
            x (Tensor): Node feature matrix
                :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
            batch (LongTensor): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots,
                B-1\}}^N`, which assigns each node to a specific example.
            bsize (int, optional): Batch size :math:`B`.
                Automatically calculated if not given. (default: :obj:`None`)
        :rtype: :class:`Tensor`
        """
        bsize = int(batch.max().item() + 1) if bsize is None else bsize
        n_nodes = scatter_add(torch.ones_like(x), batch, dim=0, dim_size=bsize)
        if self.family == "softmax":
            out = scatter_softmax(self.p * x.detach(), batch, dim=0)
            return scatter_add(x * out, batch, dim=0,
                               dim_size=bsize) * n_nodes / (1 + self.beta * (n_nodes - 1))

        elif self.family == "power":
            # numerical stability - avoid powers of large numbers or negative ones
            min_x, max_x = 1e-7, 1e+3
            torch.clamp_(x, min_x, max_x)
            out = scatter_add(torch.pow(x, self.p), batch, dim=0,
                              dim_size=bsize) / (1 + self.beta * (n_nodes - 1))
            torch.clamp_(out, min_x, max_x)
            return torch.pow(out, 1 / self.p)
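A short worked check (toy values, not part of the original module) of the scaling factor used above: the softmax branch multiplies the attention-weighted average by n_nodes / (1 + beta * (n_nodes - 1)), which gives a sum-like readout at beta = 0 and a mean-like one at beta = 1.

import torch

n_nodes = torch.tensor([[3.0], [5.0]])  # nodes per graph in a toy batch

for beta in (0.0, 0.5, 1.0):
    scale = n_nodes / (1 + beta * (n_nodes - 1))
    print(beta, scale.flatten().tolist())
# beta=0.0 -> [3.0, 5.0]   (multiply the weighted average by n: sum-like)
# beta=1.0 -> [1.0, 1.0]   (keep the weighted average: mean-like)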
Example 7
 def forward(self, q, k, index=None, size=None):
     batchlize = q.dim() > 2
     if batchlize:
         attn_score = torch.einsum('bij, bij->bi', q, k)
     else:
         attn_score = torch.einsum('ij, ij->i', q, k)
     attn_score = attn_score / self.temperature
     attn_score = self.dropout(scatter_softmax(attn_score, index, dim=-1))
     return attn_score
Example 8
 def forward(self, q, k, index=None, size=None):
     batchlize = q.dim() > 3
     attn = self.attn
     if batchlize:
         attn = self.attn.unsqueeze(0)
     inp = torch.cat([q, k], dim=-1)
     # apply a LeakyReLU to the raw scores before the grouped softmax
     attn_score = F.leaky_relu((inp * attn).sum(dim=-1), self.negative_slope)
     attn_score = self.dropout(scatter_softmax(attn_score, index, dim=1))
     return attn_score
Example 9
    def aggregate(self, inputs, index, ptr=None, dim_size=None):

        if self.aggr in ['add', 'mean', 'max', None]:
            return super(GenMessagePassing, self).aggregate(inputs, index, ptr, dim_size)

        elif self.aggr in ['softmax_sg', 'softmax', 'softmax_sum']:

            if self.learn_t:
                out = scatter_softmax(inputs*self.t, index, dim=self.node_dim)
            else:
                with torch.no_grad():
                    out = scatter_softmax(inputs*self.t, index, dim=self.node_dim)

            out = scatter(inputs*out, index, dim=self.node_dim,
                          dim_size=dim_size, reduce='sum')

            if self.aggr == 'softmax_sum':
                self.sigmoid_y = torch.sigmoid(self.y)
                degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
                out = torch.pow(degrees, self.sigmoid_y) * out

            return out


        elif self.aggr in ['power', 'power_sum']:
            min_value, max_value = 1e-7, 1e1
            torch.clamp_(inputs, min_value, max_value)
            out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim,
                          dim_size=dim_size, reduce='mean')
            torch.clamp_(out, min_value, max_value)
            out = torch.pow(out, 1/self.p)

            if self.aggr == 'power_sum':
                self.sigmoid_y = torch.sigmoid(self.y)
                degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
                out = torch.pow(degrees, self.sigmoid_y) * out

            return out

        else:
            raise NotImplementedError('To be implemented')
Example 10
    def aggregate(self, inputs, index, ptr=None, dim_size=None):

        if self.aggr in ['add', 'mean', 'max', None]:
            return super(GenMessagePassing,
                         self).aggregate(inputs, index, ptr, dim_size)

        elif self.aggr == 'softmax':
            out = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
            out = scatter(inputs * out,
                          index,
                          dim=self.node_dim,
                          dim_size=dim_size,
                          reduce='sum')
            return out

        elif self.aggr == 'softmax_sg':
            with torch.no_grad():
                out = scatter_softmax(inputs * self.t,
                                      index,
                                      dim=self.node_dim)
            out = scatter(inputs * out,
                          index,
                          dim=self.node_dim,
                          dim_size=dim_size,
                          reduce='sum')
            return out

        elif self.aggr == 'power':
            min_value, max_value = 1e-7, 1e1
            torch.clamp_(inputs, min_value, max_value)
            out = scatter(torch.pow(inputs, self.p),
                          index,
                          dim=self.node_dim,
                          dim_size=dim_size,
                          reduce='mean')
            torch.clamp_(out, min_value, max_value)
            return torch.pow(out, 1 / self.p)

        else:
            raise NotImplementedError('To be implemented')
Example 11
 def object_selection(self, features, n_obj):
     index = torch.LongTensor(
         sum([[i] * n for i, n in enumerate(n_obj)], []))
     selector = features[:, :1] * self.scale_selector
     selector = scatter_softmax(selector, index.to(self.device), dim=0)
     indptr = torch.LongTensor([0] + np.cumsum(n_obj).tolist()).to(
         self.device)
     selected_features = segment_csr(selector * features,
                                     indptr,
                                     reduce="sum")
     image = selected_features.narrow(1, 1, self.nc_out)
     depth = None
     if self.depth:
         depth = selected_features.narrow(1, 0, 1)
     return image, depth
Example 12
    def aggregate(self, inputs, index, ptr=None, dim_size=None):

        if self.aggr in ['add', 'mean', 'max', None]:
            return super(GenMessagePassing,
                         self).aggregate(inputs, index, ptr, dim_size)

        elif self.aggr in ['softmax_sg', 'softmax', 'softmax_sum']:

            # NOTE: pruning adj
            if self.edge_mask is not None:
                inputs = inputs * self.edge_mask

            if self.learn_t:
                out = scatter_softmax(inputs * self.t,
                                      index,
                                      dim=self.node_dim)
            else:
                with torch.no_grad():
                    out = scatter_softmax(inputs * self.t,
                                          index,
                                          dim=self.node_dim)

            out = scatter(inputs * out,
                          index,
                          dim=self.node_dim,
                          dim_size=dim_size,
                          reduce='sum')

            if self.aggr == 'softmax_sum':
                self.sigmoid_y = torch.sigmoid(self.y)
                degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
                out = torch.pow(degrees, self.sigmoid_y) * out

            return out
        else:
            assert False
Example 13
    def aggregate(self,
                  inputs,
                  index: Tensor,
                  dim_size: Optional[int] = None) -> Tensor:

        inputs, duplicates_idx = inputs
        index = torch.cat((index, index[duplicates_idx, ]), dim=0)
        #         print('duplicates_idx: ',duplicates_idx)

        if self.aggr == 'softmax':
            out = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
            return scatter(inputs * out,
                           index,
                           dim=self.node_dim,
                           dim_size=dim_size,
                           reduce='sum')

        elif self.aggr == 'softmax_sg':
            out = scatter_softmax(inputs * self.t, index,
                                  dim=self.node_dim).detach()
            return scatter(inputs * out,
                           index,
                           dim=self.node_dim,
                           dim_size=dim_size,
                           reduce='sum')

        else:
            min_value, max_value = 1e-7, 1e1
            torch.clamp_(inputs, min_value, max_value)
            out = scatter(torch.pow(inputs, self.p),
                          index,
                          dim=self.node_dim,
                          dim_size=dim_size,
                          reduce='mean')
            torch.clamp_(out, min_value, max_value)
            return torch.pow(out, 1 / self.p)
Example 14
    def forward(self,
                x,
                edge_index,
                edge_weight: OptTensor = None,
                batch: OptTensor = None,
                lambda_max: OptTensor = None):
        """"""
        if self.normalization != 'sym' and lambda_max is None:
            raise ValueError('You need to pass `lambda_max` to `forward()` in '
                             'case the normalization is non-symmetric.')

        num_nodes = x.shape[0]
        query = self.query(x)
        key = self.key(x)
        x = self.fc(x)

        Tx_0 = x
        out = [Tx_0]

        # propagate_type: (x: Tensor, norm: Tensor)
        khop_edge_index = k_hop_edges(edge_index, num_nodes, self.K)
        for k in range(1, self.K):
            edge_index = khop_edge_index[k - 1]
            row, col = edge_index[0], edge_index[1]
            score = torch.sum(query[row] * key[col], dim=1)
            normed_score = scatter_softmax(score, col)
            mask = normed_score > 1e-5
            Tx_2 = scatter_sum(x[row[mask]] * normed_score[mask].view(-1, 1),
                               col[mask],
                               dim=0)
            out.append(Tx_2)

        out = torch.stack(out)

        ## for importance score
        node_score = self.node_att(
            out.permute(1, 0, 2).contiguous().view(-1,
                                                   self.K * self.out_channels))

        out = torch.sum(self.att_w.view(self.K, 1, self.out_channels) * out,
                        dim=0) + self.att_bias.view(1, self.out_channels)
        if self.decoupled:
            out = out / torch.norm(out, dim=1, keepdim=True)
            out = out * node_score

        return out, node_score
Example 15
def test_softmax():
    src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    out = scatter_softmax(src, index)

    out0 = torch.softmax(torch.tensor([0.2, 0.2]), dim=-1)
    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
    out2 = torch.softmax(torch.tensor([7], dtype=torch.float), dim=-1)
    out4 = torch.softmax(torch.tensor([-1, float('-inf')]), dim=-1)

    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ],
                           dim=0)

    assert torch.allclose(out, expected)
Example 16
def _post_process(score, flatten, method, **kwargs):
    if method == 'sigmoid':
        score = torch.sigmoid(score)
    elif method == 'exp':
        score = torch.exp(score.clamp(max=70))
    elif method == 'softmax':
        mask = kwargs.get('mask')
        index = kwargs.get('index')
        dim = kwargs.get('dim', -1)
        if flatten:
            assert index is not None
            score = scatter_softmax(score, index, dim=dim)
        else:
            score = masked_softmax(score, mask=mask, dim=dim)
    else:
        assert method == 'origin'
    return score
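A hypothetical call into the 'softmax' branch of _post_process above (only the flatten=True path, which uses scatter_softmax; the tensors are made up):

import torch

score = torch.randn(6)
index = torch.tensor([0, 0, 1, 1, 1, 2])
probs = _post_process(score, flatten=True, method='softmax', index=index, dim=0)
# probs is non-negative and sums to 1 within each group defined by `index`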
Example 17
    def __get_code_vectors(self, convoluted_node_embedding, tree_indices):
        '''
        produces code vectors for a batch of trees, each code vector is of dimension dim
        convoluted_node_embedding: shape is of n * dim
        tree_indices: shape is of n
        output should be of shape T * dim
        '''
        # calculate the weight (alpha_i) for each convoluted tree node embedding
        intermediate_result = torch.matmul(convoluted_node_embedding, self.alpha.unsqueeze(1)) # n * 1
        alpha_i = scatter_softmax(intermediate_result.squeeze(), tree_indices).unsqueeze(1) # n * 1

        # multiply the weight to each convoluted tree node embedding n * dim
        res = torch.bmm(convoluted_node_embedding.unsqueeze(2), alpha_i.unsqueeze(2)).squeeze()

        # scatter add to produce code vectors, T * dim
        code_vectors = scatter_add(res, tree_indices, dim=0)

        return code_vectors
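A self-contained shape sketch of the attention pooling described in the docstring above (random toy values; `alpha` here plays the role of self.alpha, and the weighted sum mirrors the scatter_softmax / scatter_add pair used in the method):

import torch
from torch_scatter import scatter_add, scatter_softmax

n, dim = 7, 4
emb = torch.randn(n, dim)                       # convoluted node embeddings, n x dim
tree_indices = torch.tensor([0, 0, 1, 1, 1, 2, 2])
alpha = torch.randn(dim)                        # stands in for self.alpha

scores = emb @ alpha                                            # n
alpha_i = scatter_softmax(scores, tree_indices).unsqueeze(1)    # n x 1, sums to 1 per tree
code_vectors = scatter_add(alpha_i * emb, tree_indices, dim=0)  # T x dim
print(code_vectors.shape)                                       # torch.Size([3, 4])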
Example 18
def scatter_softmax(device: Type[draw_devices],
                    token_sizes: Type[draw_token_sizes],
                    dim: Type[draw_embedding_dims], *, timer: TimerSuit):
    device = device()
    token_size, num = token_sizes(), token_sizes()
    if num > token_size:
        token_size, num = num, token_size
    in_dim = dim()

    inputs = torch.randn((token_size, in_dim),
                         requires_grad=True,
                         device=device)
    index1 = torch.randint(0, num, (token_size, ), device=device)
    index2 = index1[:, None].expand_as(inputs)

    with timer.rua_forward:
        actual = rua.scatter_softmax(tensor=inputs, index=index1)

    with timer.naive_forward:
        expected = torch_scatter.scatter_softmax(src=inputs,
                                                 index=index2,
                                                 dim=0)

    with timer.rua_backward:
        torch.autograd.grad(
            actual,
            inputs,
            torch.ones_like(actual),
            create_graph=False,
        )

    with timer.naive_backward:
        torch.autograd.grad(
            expected,
            inputs,
            torch.ones_like(expected),
            create_graph=False,
        )
Example 19
def test_softmax():
    src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
    src.requires_grad_()
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    out = scatter_softmax(src, index)

    out0 = torch.softmax(torch.tensor([0.2, 0.2]), dim=-1)
    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2]), dim=-1)
    out2 = torch.softmax(torch.tensor([7], dtype=torch.float), dim=-1)
    out4 = torch.softmax(torch.tensor([-1, float('-inf')]), dim=-1)

    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ],
                           dim=0)

    assert torch.allclose(out, expected)

    out.backward(torch.randn_like(out))

    jit = torch.jit.script(scatter_softmax)
    assert jit(src, index).tolist() == out.tolist()
Example 20
        def forward(self, data):
            x, edge_attr, edge_index, batch = data.x, data.edge_attr, data.edge_index, data.batch
            ###############
            x = x.float()
            ###############
            pos = x[:, -3:]

            graph_ids, graph_node_counts = batch.unique(return_counts=True)

            time_edge_index = time_edge_indeces(x[:, 1], batch)

            edge_attr = edge_feature_constructor(x, time_edge_index)

            # Define central nodes at Center of Charge:
            CoC = scatter_sum(pos * x[:, 0].view(-1, 1), batch,
                              dim=0) / scatter_sum(
                                  x[:, 0].view(-1, 1), batch, dim=0)

            # Define edge_attr for those edges:
            cart = pos[:, -3:] - CoC[batch, :3]
            del pos
            rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)
            rho_mask = rho.squeeze() != 0
            cart[rho_mask] = cart[rho_mask] / rho[rho_mask]
            CoC_edge_attr = torch.cat([cart.type_as(x), rho.type_as(x)], dim=1)

            x = torch.cat([
                x,
                x_feature_constructor(x, graph_node_counts), edge_attr,
                x[time_edge_index[0]], CoC_edge_attr
            ],
                          dim=1)

            CoC = torch.cat([CoC, scatter_distribution(x, batch, dim=0)],
                            dim=1)

            att = scatter_softmax(graph_node_counts[batch].view(-1, 1) / 100 *
                                  self.att(x),
                                  batch,
                                  dim=0)
            att_d = scatter_distribution(att, batch, dim=0)
            mask = att.squeeze() > att_d[batch, 0] + att_d[batch, 1]

            x = x[mask]
            batch = batch[mask]

            x = self.act(self.x_encoder(x))
            CoC = self.act(self.CoC_encoder(CoC))

            h = torch.zeros((x.shape[0], self.hcs)).type_as(x)

            for i in range(N_metalayers):
                h = self.act(self.GRUCells[i](torch.cat([CoC[batch], x],
                                                        dim=1), h))

                CoC = self.act(
                    self.CoC_batch_norm[i](self.lins_CoC_msg[i](
                        scatter_distribution(h, batch, dim=0)) +
                                           self.lins_CoC_self[i](CoC)))

                h = self.act(self.GRUCells[i](torch.cat([CoC[batch], x],
                                                        dim=1), h))

                x = self.act(self.x_batch_norm[i](self.lins_x_msg[i](h) +
                                                  self.lins_x_self[i](x)))

            CoC = torch.cat([CoC, scatter_distribution(x, batch, dim=0)],
                            dim=1)

            for batch_norm, lin in zip(self.decoder_batch_norms,
                                       self.decoders):
                CoC = self.act(batch_norm(lin(CoC)))

            CoC = self.decoder(CoC)
            return CoC
 def forward(self, x, batch, CoC):
     att = scatter_softmax(self.att_lin(
         torch.cat([x, CoC[batch]], dim=1)),
                           batch,
                           dim=0)
     return scatter_sum(att * x, batch, dim=0)