Ejemplo n.º 1
0
    def mulmul(self, Theta_t):
        """Return the ratio of two multiplicative aggregations of ``Theta_t``.

        The numerator multiplies ``Theta_t`` together per group of
        ``self.tar_nodes`` (original comment labels it per node, ``[N]``) and
        then gathers the result onto edges through ``self.src_nodes``
        (``[E]``). The denominator multiplies ``Theta_t`` per group of
        ``self.cave_index`` and keeps only the first ``self.E`` entries.

        NOTE(review): there is no guard against zero entries in the
        denominator; a zero cavity product yields inf/nan.
        """
        per_node = scatter_mul(Theta_t, index=self.tar_nodes)  # [N]
        per_edge = per_node[self.src_nodes]  # [E]
        cavity = scatter_mul(Theta_t, index=self.cave_index)[:self.E]
        return per_edge / cavity
Ejemplo n.º 2
0
    def update_psi_h_fast(self, write_nodes_slice_tensor,
                          read_messages_slice_tensor, index_tensor):
        """Recompute the marginals of a slice of nodes and keep ``self.h`` in sync.

        Each message selected by ``read_messages_slice_tensor`` becomes the
        factor ``1 + m * (exp(beta * w) - 1)``, where ``w`` is the indexed
        weight when ``self.is_weighted`` and 1 otherwise. The factors are
        multiplied together into a ones-initialised
        ``(num_nodes, num_groups)`` buffer along the rows given by
        ``index_tensor``. The rows selected by ``write_nodes_slice_tensor`` are
        then scaled by ``exp(self.h)``, normalised over the last axis and
        stored in ``self.marginal_psi``.

        ``self.h`` is adjusted around that write. The term
        ``-beta * mean_w * sum(old psi)`` is removed before the overwrite, and
        ``-beta * mean_w * sum(new psi)`` is added afterwards (sums are taken
        over the selected nodes, axis 0).

        Args:
            write_nodes_slice_tensor: selection of the nodes whose marginals
                are rewritten.
            read_messages_slice_tensor: selection of the messages to combine.
                Presumably a boolean mask, since a zero sum is treated as
                "isolated node" (TODO confirm).
            index_tensor: destination row in ``[0, num_nodes)`` for each
                selected message, passed to ``scatter_mul``.

        Returns:
            None. ``self.marginal_psi`` and ``self.h`` are updated in place;
            nothing changes when the read selection sums to zero.
        """
        # Nothing selected to read: leave marginals and h untouched.
        if read_messages_slice_tensor.sum() == 0:  # isolated node
            return
        # Combine all incoming messages multiplicatively (a product, not a
        # sum). The messages are cloned unless gradients are disabled,
        # presumably so later in-place writes do not invalidate the autograd
        # graph (TODO confirm).
        src = 1 + (self.message_map[read_messages_slice_tensor].clone()
                   if not self.disable_gradient else
                   self.message_map[read_messages_slice_tensor]) * (
                       torch.exp(self.beta *
                                 (self.w_indexed[read_messages_slice_tensor]
                                  if self.is_weighted else 1)) - 1)
        # Ones are the multiplicative identity, so rows that receive no
        # message keep the value 1 before scaling.
        out = self.marginal_psi.new_ones((self.num_nodes, self.num_groups))
        out = scatter_mul(src, index_tensor, out=out, dim=0)
        # Keep only the nodes being rewritten, apply exp(h), normalise rows.
        out = out[write_nodes_slice_tensor]
        out = out * torch.exp(self.h)
        out = out / out.sum(-1).reshape(-1, 1)

        # subtract the old psi: remove the -beta * mean_w * sum(psi) term of
        # the marginals that are about to be replaced.
        self.h -= -self.beta * self.mean_w * (
            self.marginal_psi[write_nodes_slice_tensor].clone()
            if not self.disable_gradient else
            self.marginal_psi[write_nodes_slice_tensor]).sum(0)
        # update marginal_psi
        self.marginal_psi[write_nodes_slice_tensor] = out
        # add the new psi: same term, computed from the freshly stored values.
        self.h += -self.beta * self.mean_w * (
            self.marginal_psi[write_nodes_slice_tensor].clone()
            if not self.disable_gradient else
            self.marginal_psi[write_nodes_slice_tensor]).sum(0)
Ejemplo n.º 3
0
    def forward(self, data, state=None, mode='train/test'):
        """Predict per-graph start/end node distributions and the next state.

        Node embeddings come from ``self.graph_model(data)``. ``data.batch``
        holds the graph id of every node in the flattened batch.

        The mean-pooled graph embedding is concatenated with ``state`` and
        projected by ``self.lin2``. That per-graph vector is broadcast back to
        the graph's nodes, and each node is scored as a start candidate
        (``self.lin_start``) and an end candidate (``self.lin_end``).

        Scores are normalised per graph by ``softmax(scores, batch)``. Here
        that helper returns both the probabilities and their logs.

        Args:
            data: batched graph object exposing ``batch``. In
                ``'train/test'`` mode it must also expose ``y``, whose columns
                0 and 1 hold the start/end labels.
            state: per-graph recurrent state, concatenated along dim 1. It
                must be a tensor despite the ``None`` default, because
                ``torch.cat`` cannot concatenate ``None``.
            mode: ``'train/test'`` to compute loss and accuracy, or
                ``'infer'`` to return the attention distributions.

        Returns:
            ``(loss, correct, state)`` in ``'train/test'`` mode. ``correct``
            is the per-graph product of per-node match indicators, i.e. 1.0
            only when every node's start and end selection agrees with its
            label.
            ``(attn_start, attn_end, state)`` in ``'infer'`` mode.

        Raises:
            ValueError: for any other ``mode``. Previously this silently
            returned ``None`` after running the whole model.
        """
        if mode not in ('train/test', 'infer'):
            raise ValueError(
                "mode must be 'train/test' or 'infer', got {!r}".format(mode))

        batch = data.batch  # graph id of each node
        x = self.graph_model(data)

        graph_pooled = global_mean_pool(x, batch)
        processed = self.lin2(torch.cat([graph_pooled, state], dim=1))
        # Broadcast each graph's context vector onto all of its nodes.
        x_cat = torch.cat([x, processed[batch]], dim=1)
        scores_start = F.leaky_relu(self.lin_start(x_cat),
                                    negative_slope=self.neg_slope).view(-1)
        scores_end = F.leaky_relu(self.lin_end(x_cat),
                                  negative_slope=self.neg_slope).view(-1)

        attn_start, log_attn_start = softmax(scores_start, batch)
        attn_end, log_attn_end = softmax(scores_end, batch)
        state = F.leaky_relu(
            self.lin_state(torch.cat([processed, state], dim=1)),
            negative_slope=self.neg_slope)

        if mode == 'infer':
            return attn_start, attn_end, state

        labels_start = data.y[:, 0]
        labels_end = data.y[:, 1]
        # Minus the labelled nodes' log-probabilities, averaged over all
        # nodes in the batch.
        loss_start = -(log_attn_start * labels_start.float()).mean()
        loss_end = -(log_attn_end * labels_end.float()).mean()
        loss = loss_start + loss_end

        # A node counts as selected when its score equals its graph's maximum
        # (ties select several nodes).
        max_scores_start = scatter_max(scores_start, batch)[0][batch]
        max_scores_end = scatter_max(scores_end, batch)[0][batch]
        selected_start = max_scores_start.eq(scores_start).long()
        selected_end = max_scores_end.eq(scores_end).long()
        # Per-graph product of 0/1 match indicators: 1.0 only if all match.
        correct_start = scatter_mul(
            selected_start.eq(labels_start).float(), batch)
        correct_end = scatter_mul(
            selected_end.eq(labels_end).float(), batch)
        correct = correct_start * correct_end

        return loss, correct, state
Ejemplo n.º 4
0
    def update_message_fast(self, write_messages_slice_tensor,
                            read_messages_slice_tensor, index_tensor):
        """Recompute a slice of messages and report how far they moved.

        Each message picked by ``read_messages_slice_tensor`` becomes the
        factor ``1 + m * (exp(beta * w) - 1)``, where ``w`` is the indexed
        weight when ``self.is_weighted`` and 1 otherwise. The factors are
        multiplied into a ones-initialised ``(num_messages, num_groups)``
        buffer along the rows given by ``index_tensor``. The rows chosen by
        ``write_messages_slice_tensor`` are then scaled by ``exp(self.h)``,
        normalised over the last axis and written back into
        ``self.message_map``.

        Returns:
            The largest absolute element-wise difference between the new and
            old messages (on detached tensors). Returns the integer ``0``,
            writing nothing, when the read selection sums to zero (isolated
            node).
        """
        if read_messages_slice_tensor.sum() == 0:  # isolated node
            return 0

        # Copy the incoming messages unless gradients are disabled,
        # presumably so the in-place write at the end cannot clobber a tensor
        # autograd still needs (TODO confirm).
        incoming = self.message_map[read_messages_slice_tensor]
        if not self.disable_gradient:
            incoming = incoming.clone()

        weights = (self.w_indexed[read_messages_slice_tensor]
                   if self.is_weighted else 1)
        factors = 1 + incoming * (torch.exp(self.beta * weights) - 1)

        # Multiplicative aggregation; rows that receive nothing stay at 1.
        buffer = self.message_map.new_ones((self.num_messages, self.num_groups))
        product = scatter_mul(factors, index_tensor, out=buffer, dim=0)

        updated = product[write_messages_slice_tensor] * torch.exp(self.h)
        updated = updated / updated.sum(-1).reshape(-1, 1)

        fresh = updated.detach()
        stale = self.message_map[write_messages_slice_tensor].detach()
        max_diff = (fresh - stale).abs().max()

        self.message_map[write_messages_slice_tensor] = updated
        return max_diff
Ejemplo n.º 5
0
def scatter_mul(device: Type[draw_devices],
                token_sizes: Type[draw_token_sizes],
                dim: Type[draw_embedding_dims], *, timer: TimerSuit):
    """Time ``rua.scatter_mul`` against ``torch_scatter.scatter_mul``.

    The draws happen in this order: a device, two token counts, then an
    embedding width. Of the two counts, the larger becomes the number of
    input rows and the smaller the exclusive upper bound of the random
    bucket index.

    Both implementations run a forward and a backward pass on the same
    random input. Each phase is wrapped in its matching ``timer`` section.
    The outputs are not compared; only timings are recorded.
    """
    device = device()
    first_draw, second_draw = token_sizes(), token_sizes()
    token_size = max(first_draw, second_draw)
    num = min(first_draw, second_draw)
    in_dim = dim()

    inputs = torch.randn((token_size, in_dim),
                         requires_grad=True,
                         device=device)
    flat_index = torch.randint(0, num, (token_size, ), device=device)
    # torch_scatter receives the index broadcast to the full input shape.
    wide_index = flat_index[:, None].expand_as(inputs)

    with timer.rua_forward:
        actual = rua.scatter_mul(tensor=inputs, index=flat_index)

    with timer.naive_forward:
        expected = torch_scatter.scatter_mul(src=inputs, index=wide_index, dim=0)

    with timer.rua_backward:
        torch.autograd.grad(
            outputs=actual,
            inputs=inputs,
            grad_outputs=torch.ones_like(actual),
            create_graph=False,
        )

    with timer.naive_backward:
        torch.autograd.grad(
            outputs=expected,
            inputs=inputs,
            grad_outputs=torch.ones_like(expected),
            create_graph=False,
        )
Ejemplo n.º 6
0
    def theta_aggr(self):
        """Return ``scatter_mul`` of ``self.Theta_t`` grouped by ``self.tar_nodes``, plus ``self.Ps_i``.

        The first element is the multiplicative aggregation. The second is
        the current ``self.Ps_i`` attribute, returned as-is.
        """
        aggregated = scatter_mul(self.Theta_t, index=self.tar_nodes)
        return (aggregated, self.Ps_i)
Ejemplo n.º 7
0
 def influence(self):
     """Refresh ``self.Ps_i`` and return ``T.sum(1 - self.Ps_i)``.

     ``self.Ps_i`` (original comment: the probability of node i being S) is
     set to ``self.Ps_i_0`` times the ``scatter_mul`` aggregation of
     ``self.Theta_t`` grouped by ``self.tar_nodes``.

     NOTE(review): if ``Ps_i`` holds probabilities, the return value is the
     expected number of nodes not in state S. ``T`` is presumably a torch
     alias; confirm at the file's imports.
     """
     base = self.Ps_i_0
     self.Ps_i = base * scatter_mul(self.Theta_t, index=self.tar_nodes)
     not_s = 1 - self.Ps_i
     return T.sum(not_s)
Ejemplo n.º 8
0
# Division: every slot of a ones-initialised `out` is divided by the src
# values scattered into it, row by row (see the expected output below).
src = torch.tensor([[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]], dtype=torch.float)
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = scatter_div(src, index, out=src.new_ones((2, 6)))

print(out)
# Expected:
# tensor([[1.0000, 1.0000, 0.2500, 0.5000, 0.5000, 1.0000],
#         [0.5000, 0.2500, 0.5000, 1.0000, 1.0000, 1.0000]])

# Max, min and mean (followed by mul, std and sub) reductions, reusing the
# same index layout as above.
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])

out, argmax = scatter_max(src, index)
print(out, argmax)

out, argmin = scatter_min(src, index)
print(out, argmin)

out = scatter_mean(src, index)
print(out)

out = scatter_mul(src, index)
print(out)

out = scatter_std(src, index)
print(out)

out = scatter_sub(src, index)
print(out)