Example #1
0
    def recon_loss(self,
                   z,
                   batch,
                   negsampling=True,
                   nodespergraph=NODESPERGRAPH,
                   negfactor=20,
                   cdf_tol=None):
        r"""Compute the edge and node reconstruction losses for a batch.

        The edge loss is the mean of :obj:`self.edge_decoder` over the
        positive edges :obj:`batch.edge_index`. When :obj:`negsampling` is
        set, the mean over negative edges is added to it. Those negatives are
        drawn per graph with :func:`batched_negative_sampling`, and self-loops
        are treated as existing edges so they are never drawn. The node loss
        is the L1 loss between :obj:`self.node_decoder`'s prediction and the
        original :obj:`batch.x`.

        Args:
            z (Tensor): The latent space :math:`\mathbf{Z}`. Presumably one
                row per node of :obj:`batch` (TODO confirm).
            batch: Batched graph providing :obj:`edge_index`,
                :obj:`edge_attr`, :obj:`batch` (passed to the sampler as the
                batch vector) and :obj:`x`. :obj:`batch.x` is temporarily set
                to :obj:`z` while the node decoder runs and is restored
                afterwards.
            negsampling (bool, optional): Whether to add the negative-edge
                term. (default: :obj:`True`)
            nodespergraph (int, optional): Together with :obj:`negfactor`,
                sets the requested number of negative samples,
                :obj:`int(nodespergraph ** 2 / negfactor)`.
            negfactor (int, optional): Divisor for the negative sample count.
                (default: :obj:`20`)
            cdf_tol (optional): Forwarded to :obj:`self.edge_decoder`.
                (default: :obj:`None`)

        Returns:
            tuple: :obj:`(edge_loss, node_loss)`. The negative-edge term
            contributes :obj:`0.0` when :obj:`negsampling` is false.
            :obj:`node_loss` is :obj:`0.0` when :obj:`self.node_decoder`
            is :obj:`None`.
        """
        pos_edge_loss = self.edge_decoder(z,
                                          batch.edge_index,
                                          batch.edge_attr,
                                          cdf_tol=cdf_tol).mean()

        if negsampling:
            # Do not include self-loops in negative samples: give the sampler
            # an edge set that contains exactly one (i, i) loop per node, so
            # every self-loop counts as an existing edge.
            num_nodes = batch.batch.numel()
            sampling_edge_index, _ = remove_self_loops(batch.edge_index)
            sampling_edge_index, _ = add_self_loops(sampling_edge_index,
                                                    num_nodes=num_nodes)
            # add_self_loops appends the loops at the end. The batched
            # sampler splits edges per graph and expects them to be grouped,
            # so regroup by the graph id of each edge's source node.
            order = torch.argsort(batch.batch[sampling_edge_index[0]])
            sampling_edge_index = sampling_edge_index[:, order]

            negsamples = batched_negative_sampling(
                sampling_edge_index,
                batch.batch,
                num_neg_samples=int(nodespergraph**2 / negfactor)).to(z.device)

            # Negative edges are scored with an all-zero edge attribute.
            neg_edge_attr = torch.zeros(negsamples.size(1), device=z.device)
            neg_edge_loss = self.edge_decoder(z,
                                              negsamples,
                                              neg_edge_attr,
                                              cdf_tol=cdf_tol).mean()
        else:
            neg_edge_loss = 0.0

        if self.node_decoder is not None:
            x = batch.x
            # The node decoder reads its input from batch.x. Feed it z, then
            # restore the caller's features even if decoding raises.
            batch.x = z
            try:
                pred = self.node_decoder(batch)
            finally:
                batch.x = x
            node_loss = L1Loss()(pred, x)
        else:
            node_loss = 0.0

        edge_loss = pos_edge_loss + neg_edge_loss
        return edge_loss, node_loss
def test_batched_negative_sampling():
    # Two copies of the same 4-node graph: nodes 0-3 and nodes 4-7.
    single = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]])
    edge_index = torch.cat([single, single + 4], dim=1)
    batch = torch.tensor([0] * 4 + [1] * 4)

    neg_edge_index = batched_negative_sampling(edge_index, batch)
    assert neg_edge_index.size(1) == edge_index.size(1)

    num_nodes = batch.numel()
    pos_adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
    pos_adj[edge_index[0], edge_index[1]] = True
    neg_adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
    neg_adj[neg_edge_index[0], neg_edge_index[1]] = True

    # No sampled edge may coincide with an existing edge...
    assert not (pos_adj & neg_adj).any()
    # ...and none may cross between the two graphs.
    assert not neg_adj[:4, 4:].any()
    assert not neg_adj[4:, :4].any()
    def negative_sampling(self, edge_index: Tensor, num_nodes: int,
                          batch: OptTensor = None) -> Tensor:
        """Draw negative edges for ``edge_index``.

        The number of samples requested is
        ``int(self.neg_sample_ratio * self.edge_sample_ratio * E)``, where
        ``E`` is the number of input edges. If ``self.is_undirected`` is
        falsy and ``edge_index`` is not already undirected, the edges are
        first symmetrised with ``to_undirected``.

        Args:
            edge_index: [2, E] existing (positive) edges.
            num_nodes: Node count, used by the undirected check/conversion
                and by unbatched sampling.
            batch: Optional batch vector. When given,
                ``batched_negative_sampling`` is used instead of
                ``negative_sampling``.

        Returns:
            The sampled negative edge index.
        """
        sample_ratio = self.neg_sample_ratio * self.edge_sample_ratio
        num_neg_samples = int(sample_ratio * edge_index.size(1))

        needs_symmetrising = not (self.is_undirected or is_undirected(
            edge_index, num_nodes=num_nodes))
        if needs_symmetrising:
            edge_index = to_undirected(edge_index, num_nodes=num_nodes)

        if batch is not None:
            return batched_negative_sampling(
                edge_index, batch, num_neg_samples=num_neg_samples)

        # Bare name: the module-level sampler, not this method.
        return negative_sampling(edge_index, num_nodes,
                                 num_neg_samples=num_neg_samples)
Example #4
0
def test_bipartite_batched_negative_sampling():
    # Three bipartite graphs, each with 2 source and 4 destination nodes.
    # Copy k is the first graph shifted by k * (2 sources, 4 destinations).
    shift = torch.tensor([[2], [4]])
    first = torch.as_tensor([[0, 0, 1, 1], [0, 1, 2, 3]])
    edge_index = torch.cat([first, first + shift, first + 2 * shift], dim=1)
    src_batch = torch.arange(3).repeat_interleave(2)
    dst_batch = torch.arange(3).repeat_interleave(4)

    neg_edge_index = batched_negative_sampling(edge_index,
                                               (src_batch, dst_batch))
    assert neg_edge_index.size(1) <= edge_index.size(1)

    shape = (src_batch.numel(), dst_batch.numel())
    pos_adj = torch.zeros(shape, dtype=torch.bool)
    pos_adj[edge_index[0], edge_index[1]] = True
    neg_adj = torch.zeros(shape, dtype=torch.bool)
    neg_adj[neg_edge_index[0], neg_edge_index[1]] = True

    # Negatives never hit a positive edge...
    assert (pos_adj & neg_adj).sum() == 0
    # ...and every edge in either set occupies its own adjacency cell.
    total = edge_index.size(1) + neg_edge_index.size(1)
    assert (pos_adj | neg_adj).sum() == total
Example #5
0
    def forward_to_reconstruct_edges(self,
                                     x,
                                     edge_index,
                                     r_scaling_1,
                                     r_bias_1,
                                     r_scaling_2,
                                     r_bias_2,
                                     batch=None):
        """Score positive edges together with freshly sampled negative edges.

        Each edge (j, i) is scored by the dot product ``<x[i], x[j]>``. The
        score is then passed twice through ``scaling * elu(score) + bias``.
        No sigmoid or clamp is applied here, so the scores are not bounded
        to [0, 1] by this method.

        :param x: [N, F]
        :param edge_index: [2, E] positive edges
        :param r_scaling_1: [1]
        :param r_bias_1: [1]
        :param r_scaling_2: [1]
        :param r_bias_2: [1]
        :param batch: None (single graph) or [N] batch vector
        :return: [E + neg_E] scores, one per column of
            ``cat([edge_index, neg_edge_index], dim=-1)``; the first E entries
            are the positive edges.
        """

        if batch is None:
            num_neg_samples = int(self.neg_sample_ratio * edge_index.size(1))
            neg_edge_index = negative_sampling(edge_index=edge_index,
                                               num_nodes=x.size(0),
                                               num_neg_samples=num_neg_samples)
        else:
            # NOTE(review): self.neg_sample_ratio is not applied here, so
            # batched_negative_sampling uses its own default sample count.
            # Confirm this asymmetry with the unbatched path is intended.
            neg_edge_index = batched_negative_sampling(edge_index=edge_index,
                                                       batch=batch)

        total_edge_index = torch.cat([edge_index, neg_edge_index],
                                     dim=-1)  # [2, E + neg_E]
        total_edge_index_j, total_edge_index_i = total_edge_index  # [E + neg_E]
        x_i = torch.index_select(x, 0, total_edge_index_i)  # [E + neg_E, F]
        x_j = torch.index_select(x, 0, total_edge_index_j)  # [E + neg_E, F]

        # Row-wise dot product of the two endpoint feature vectors.
        recon = torch.einsum("ef,ef->e", x_i, x_j)  # [E + neg_E]
        recon = r_scaling_1 * F.elu(recon) + r_bias_1
        recon = r_scaling_2 * F.elu(recon) + r_bias_2
        return recon
Example #6
0
    def forward(self,
                x,
                edge_index,
                size=None,
                batch=None,
                neg_edge_index=None,
                attention_edge_index=None):
        """Propagate messages and, when requested, compute attention scores
        on positive and negative edges for self-supervision.

        :param x: [N, F], or a pair (x_src, x_dst) where either element may
            be None
        :param edge_index: [2, E]
        :param size: forwarded to ``self.propagate``. When None and ``x`` is a
            tensor, self-loops are reset to exactly one per node.
        :param batch: None or [B]. NOTE(review): it is passed as the ``batch``
            argument of batched_negative_sampling, which presumably expects
            one graph id per node; confirm the documented shape.
        :param neg_edge_index: When using explicitly given negative edges.
        :param attention_edge_index: [2, E'], Use for link prediction
        :return: the output of ``self.propagate``. When the self-supervised
            branch runs, the attention over positive+negative edges and its
            labels are also stored via ``self._update_cache`` under
            "att_with_negatives" and "att_label".
        """
        # Pretraining noise: randomly drop edges. dropout_adj receives
        # training=self.training.
        if self.pretraining and self.pretraining_noise_ratio > 0.0:
            edge_index, _ = dropout_adj(
                edge_index,
                p=self.pretraining_noise_ratio,
                force_undirected=is_undirected(edge_index),
                num_nodes=x.size(0),
                training=self.training)

        # Replace any existing self-loops with exactly one loop per node.
        if size is None and torch.is_tensor(x):
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # [N, F0] * [F0, heads * F] = [N, heads * F]
        if torch.is_tensor(x):
            x = torch.matmul(x, self.weight)
        else:
            x = (None if x[0] is None else torch.matmul(x[0], self.weight),
                 None if x[1] is None else torch.matmul(x[1], self.weight))

        propagated = self.propagate(edge_index, size=size, x=x)

        # Self-supervised attention branch: runs for a SuperGAT layer in
        # training mode, or whenever explicit attention/negative edges are
        # supplied.
        if (self.is_super_gat
                and self.training) or (attention_edge_index
                                       is not None) or (neg_edge_index
                                                        is not None):

            # Device of the layer parameters; used for index/label tensors.
            device = next(self.parameters()).device
            num_pos_samples = int(self.edge_sample_ratio * edge_index.size(1))
            num_neg_samples = int(self.neg_sample_ratio *
                                  self.edge_sample_ratio * edge_index.size(1))

            # Choose the negative edges:
            # - none when scoring an explicit attention_edge_index,
            # - the caller's neg_edge_index when given,
            # - otherwise sampled (batched if a batch vector is given).
            if attention_edge_index is not None:
                neg_edge_index = None

            elif neg_edge_index is not None:
                pass

            elif batch is None:
                if self.to_undirected_at_neg:
                    edge_index_for_ns = to_undirected(edge_index,
                                                      num_nodes=x.size(0))
                else:
                    edge_index_for_ns = edge_index
                neg_edge_index = negative_sampling(
                    edge_index=edge_index_for_ns,
                    num_nodes=x.size(0),
                    num_neg_samples=num_neg_samples,
                )
            else:
                neg_edge_index = batched_negative_sampling(
                    edge_index=edge_index,
                    batch=batch,
                    num_neg_samples=num_neg_samples,
                )

            # Sub-sample positive edges without replacement when
            # edge_sample_ratio < 1. Uses Python's `random` module RNG, not
            # torch's.
            if self.edge_sample_ratio < 1.0:
                pos_indices = random.sample(range(edge_index.size(1)),
                                            num_pos_samples)
                pos_indices = torch.tensor(pos_indices).long().to(device)
                pos_edge_index = edge_index[:, pos_indices]
            else:
                pos_edge_index = edge_index

            att_with_negatives = self._get_attention_with_negatives(
                x=x,
                edge_index=pos_edge_index,
                neg_edge_index=neg_edge_index,
                total_edge_index=attention_edge_index,
            )  # [E + neg_E, heads]

            # Labels: 1 for the first pos_edge_index.size(1) rows (positives),
            # 0 for the rest. In training they are rebuilt unless a cached
            # label exists and self.cache_label is truthy, in which case the
            # cached one is reused. Outside training the label is None.
            # NOTE(review): the cached label is reused without checking that
            # its length matches att_with_negatives; confirm edge counts stay
            # fixed across steps when cache_label is enabled.
            if self.training and (self.cache["att_label"] is None
                                  or not self.cache_label):
                att_label = torch.zeros(
                    att_with_negatives.size(0)).float().to(device)
                att_label[:pos_edge_index.size(1)] = 1.
            elif self.training and self.cache["att_label"] is not None:
                att_label = self.cache["att_label"]
            else:
                att_label = None
            self._update_cache("att_label", att_label)
            self._update_cache("att_with_negatives", att_with_negatives)

        return propagated