def recon_loss(self, z, batch, negsampling=True, nodespergraph=NODESPERGRAPH,
               negfactor=20, cdf_tol=None):
    r"""Computes the edge reconstruction loss for the positive edges in
    :obj:`batch.edge_index` (and, if :obj:`negsampling` is set, for
    negatively sampled edges), plus an L1 node reconstruction loss when a
    node decoder is present.

    Args:
        z (Tensor): The latent space :math:`\mathbf{Z}`.
        batch (Batch): Input batch providing :obj:`edge_index`,
            :obj:`edge_attr`, :obj:`batch` and node features :obj:`x`.
        negsampling (bool, optional): If set to :obj:`True`, also trains
            against negatively sampled edges. (default: :obj:`True`)
        nodespergraph (int, optional): Nodes per graph; together with
            :obj:`negfactor` it sizes the negative sample set.
            (default: :obj:`NODESPERGRAPH`)
        negfactor (int, optional): Divisor for the number of negative
            samples. (default: :obj:`20`)
        cdf_tol (float, optional): Tolerance forwarded to the edge decoder.
            (default: :obj:`None`)
    """
    pos_edge_loss = self.edge_decoder(z, batch.edge_index, batch.edge_attr,
                                      cdf_tol=cdf_tol).mean()
    if negsampling:
        # Add self-loops to the positive set so they are never drawn as
        # negative samples.
        pos_edge_index, _ = remove_self_loops(batch.edge_index)
        pos_edge_index, _ = add_self_loops(pos_edge_index)
        negsamples = batched_negative_sampling(
            pos_edge_index, batch.batch,
            num_neg_samples=int(nodespergraph**2 / negfactor)).cuda()
        neg_edge_loss = self.edge_decoder(
            z, negsamples,
            torch.full((negsamples.size(1), ), 0.0).cuda(),
            cdf_tol=cdf_tol).mean()
    else:
        neg_edge_loss = 0.0

    if self.node_decoder is not None:
        x = batch.x
        batch.x = z
        pred = self.node_decoder(batch)
        criterion = L1Loss()
        node_loss = criterion(pred, x)
    else:
        node_loss = 0.0

    edge_loss = pos_edge_loss + neg_edge_loss
    return edge_loss, node_loss
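# Usage sketch, not from the source: assumes `model` is the autoencoder that
# owns `recon_loss` above and also exposes a hypothetical `encode(batch)`
# returning latents `z`; `batch` is a torch_geometric `Batch` with `x`,
# `edge_index`, `edge_attr` and `batch` attributes.
def train_step(model, optimizer, batch):
    optimizer.zero_grad()
    z = model.encode(batch)  # hypothetical encoder
    edge_loss, node_loss = model.recon_loss(z, batch, negsampling=True)
    loss = edge_loss + node_loss  # equal weighting is an assumption
    loss.backward()
    optimizer.step()
    return loss.item()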
def test_batched_negative_sampling():
    edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]])
    edge_index = torch.cat([edge_index, edge_index + 4], dim=1)
    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

    neg_edge_index = batched_negative_sampling(edge_index, batch)
    assert neg_edge_index.size(1) == edge_index.size(1)

    adj = torch.zeros(8, 8, dtype=torch.bool)
    adj[edge_index[0], edge_index[1]] = True

    neg_adj = torch.zeros(8, 8, dtype=torch.bool)
    neg_adj[neg_edge_index[0], neg_edge_index[1]] = True

    # Negative edges never overlap with positive edges...
    assert (adj & neg_adj).sum() == 0
    # ...and never cross graph boundaries within the batch.
    assert neg_adj[:4, 4:].sum() == 0
    assert neg_adj[4:, :4].sum() == 0
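# Minimal sketch of calling batched_negative_sampling directly with an
# explicit per-graph sample budget; the graphs mirror the test above, and
# `num_neg_samples=2` is an illustrative value, not from the source.
import torch
from torch_geometric.utils import batched_negative_sampling

edge_index = torch.tensor([[0, 0, 1, 2], [0, 1, 2, 3]])
edge_index = torch.cat([edge_index, edge_index + 4], dim=1)
batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
# Negatives are sampled per graph, so they never cross graph boundaries.
neg_edge_index = batched_negative_sampling(edge_index, batch, num_neg_samples=2)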
def negative_sampling(self, edge_index: Tensor, num_nodes: int,
                      batch: OptTensor = None) -> Tensor:
    num_neg_samples = int(self.neg_sample_ratio * self.edge_sample_ratio *
                          edge_index.size(1))

    if not self.is_undirected and not is_undirected(
            edge_index, num_nodes=num_nodes):
        edge_index = to_undirected(edge_index, num_nodes=num_nodes)

    if batch is None:
        neg_edge_index = negative_sampling(edge_index, num_nodes,
                                           num_neg_samples=num_neg_samples)
    else:
        neg_edge_index = batched_negative_sampling(
            edge_index, batch, num_neg_samples=num_neg_samples)

    return neg_edge_index
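# Worked example of the sample-count arithmetic above (ratios are made up):
# with neg_sample_ratio = 0.5, edge_sample_ratio = 0.8 and E = 100 positive
# edges, the layer asks for int(0.5 * 0.8 * 100) = 40 negative edges.
neg_sample_ratio, edge_sample_ratio, E = 0.5, 0.8, 100
assert int(neg_sample_ratio * edge_sample_ratio * E) == 40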
def test_bipartite_batched_negative_sampling():
    edge_index1 = torch.as_tensor([[0, 0, 1, 1], [0, 1, 2, 3]])
    edge_index2 = edge_index1 + torch.tensor([[2], [4]])
    edge_index3 = edge_index2 + torch.tensor([[2], [4]])
    edge_index = torch.cat([edge_index1, edge_index2, edge_index3], dim=1)
    src_batch = torch.tensor([0, 0, 1, 1, 2, 2])
    dst_batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])

    neg_edge_index = batched_negative_sampling(edge_index,
                                               (src_batch, dst_batch))
    assert neg_edge_index.size(1) <= edge_index.size(1)

    adj = torch.zeros(6, 12, dtype=torch.bool)
    adj[edge_index[0], edge_index[1]] = True

    neg_adj = torch.zeros(6, 12, dtype=torch.bool)
    neg_adj[neg_edge_index[0], neg_edge_index[1]] = True

    assert (adj & neg_adj).sum() == 0
    assert (adj | neg_adj).sum() == edge_index.size(1) + neg_edge_index.size(1)
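# Sketch of an assumed helper (not in the source) for building the
# (src_batch, dst_batch) tuple that the bipartite call above expects,
# starting from per-graph node counts on each side.
import torch

def bipartite_batch(src_sizes, dst_sizes):
    src_batch = torch.repeat_interleave(torch.arange(len(src_sizes)),
                                        torch.tensor(src_sizes))
    dst_batch = torch.repeat_interleave(torch.arange(len(dst_sizes)),
                                        torch.tensor(dst_sizes))
    return src_batch, dst_batch

# Reproduces the assignment vectors used in the test above:
src_batch, dst_batch = bipartite_batch([2, 2, 2], [4, 4, 4])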
def forward_to_reconstruct_edges(self, x, edge_index,
                                 r_scaling_1, r_bias_1,
                                 r_scaling_2, r_bias_2, batch=None):
    """
    :param x: [N, F]
    :param edge_index: [2, E]
    :param r_scaling_1: [1]
    :param r_scaling_2: [1]
    :param r_bias_1: [1]
    :param r_bias_2: [1]
    :param batch: [N]
    :return: Edge reconstruction scores [E + neg_E]
    """
    if batch is None:
        num_neg_samples = int(self.neg_sample_ratio * edge_index.size(1))
        neg_edge_index = negative_sampling(edge_index=edge_index,
                                           num_nodes=x.size(0),
                                           num_neg_samples=num_neg_samples)
    else:
        neg_edge_index = batched_negative_sampling(edge_index=edge_index,
                                                   batch=batch)

    total_edge_index = torch.cat([edge_index, neg_edge_index], dim=-1)  # [2, E + neg_E]
    total_edge_index_j, total_edge_index_i = total_edge_index  # [E + neg_E]
    x_i = torch.index_select(x, 0, total_edge_index_i)  # [E + neg_E, F]
    x_j = torch.index_select(x, 0, total_edge_index_j)  # [E + neg_E, F]

    recon = torch.einsum("ef,ef->e", x_i, x_j)  # [E + neg_E]
    recon = r_scaling_1 * F.elu(recon) + r_bias_1
    recon = r_scaling_2 * F.elu(recon) + r_bias_2
    return recon
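# Sketch of pairing the reconstruction scores with a link-prediction loss;
# the use of binary cross-entropy here is an assumption, but the label
# layout follows the concatenation order above (positives first, then the
# sampled negatives).
import torch
import torch.nn.functional as F

def edge_recon_loss(recon, num_pos_edges):
    labels = torch.zeros_like(recon)
    labels[:num_pos_edges] = 1.0
    return F.binary_cross_entropy_with_logits(recon, labels)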
def forward(self, x, edge_index, size=None, batch=None,
            neg_edge_index=None, attention_edge_index=None):
    """
    :param x: [N, F]
    :param edge_index: [2, E]
    :param size:
    :param batch: None or [B]
    :param neg_edge_index: When using explicitly given negative edges.
    :param attention_edge_index: [2, E'], use for link prediction.
    :return:
    """
    if self.pretraining and self.pretraining_noise_ratio > 0.0:
        edge_index, _ = dropout_adj(
            edge_index, p=self.pretraining_noise_ratio,
            force_undirected=is_undirected(edge_index),
            num_nodes=x.size(0), training=self.training)

    if size is None and torch.is_tensor(x):
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

    # [N, F0] * [F0, heads * F] = [N, heads * F]
    if torch.is_tensor(x):
        x = torch.matmul(x, self.weight)
    else:
        x = (None if x[0] is None else torch.matmul(x[0], self.weight),
             None if x[1] is None else torch.matmul(x[1], self.weight))

    propagated = self.propagate(edge_index, size=size, x=x)

    if (self.is_super_gat and self.training) \
            or (attention_edge_index is not None) \
            or (neg_edge_index is not None):

        device = next(self.parameters()).device
        num_pos_samples = int(self.edge_sample_ratio * edge_index.size(1))
        num_neg_samples = int(self.neg_sample_ratio *
                              self.edge_sample_ratio * edge_index.size(1))

        if attention_edge_index is not None:
            neg_edge_index = None
        elif neg_edge_index is not None:
            pass
        elif batch is None:
            if self.to_undirected_at_neg:
                edge_index_for_ns = to_undirected(edge_index,
                                                  num_nodes=x.size(0))
            else:
                edge_index_for_ns = edge_index
            neg_edge_index = negative_sampling(
                edge_index=edge_index_for_ns,
                num_nodes=x.size(0),
                num_neg_samples=num_neg_samples,
            )
        else:
            neg_edge_index = batched_negative_sampling(
                edge_index=edge_index,
                batch=batch,
                num_neg_samples=num_neg_samples,
            )

        if self.edge_sample_ratio < 1.0:
            pos_indices = random.sample(range(edge_index.size(1)),
                                        num_pos_samples)
            pos_indices = torch.tensor(pos_indices).long().to(device)
            pos_edge_index = edge_index[:, pos_indices]
        else:
            pos_edge_index = edge_index

        att_with_negatives = self._get_attention_with_negatives(
            x=x,
            edge_index=pos_edge_index,
            neg_edge_index=neg_edge_index,
            total_edge_index=attention_edge_index,
        )  # [E + neg_E, heads]

        # Labels
        if self.training and (self.cache["att_label"] is None
                              or not self.cache_label):
            att_label = torch.zeros(att_with_negatives.size(0)).float().to(device)
            att_label[:pos_edge_index.size(1)] = 1.
        elif self.training and self.cache["att_label"] is not None:
            att_label = self.cache["att_label"]
        else:
            att_label = None
        self._update_cache("att_label", att_label)
        self._update_cache("att_with_negatives", att_with_negatives)

    return propagated
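# Hedged sketch of consuming the cached attention tensors after a forward
# pass; the cache keys follow the forward above, while averaging the scores
# over heads before binary cross-entropy is an assumption.
import torch.nn.functional as F

def attention_link_loss(layer):
    att = layer.cache["att_with_negatives"]  # [E + neg_E, heads]
    label = layer.cache["att_label"]         # [E + neg_E]
    return F.binary_cross_entropy_with_logits(att.mean(dim=-1), label)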