Exemplo n.º 1
0
 def batch_begin(self, l: Learner):
     """Append randomly corrupted copies of the batch as negative examples.

     The input batch is repeated ``self.ratio`` times. In each copy, one
     randomly chosen column of every row is overwritten with a random index
     drawn uniformly from ``[0, self.num_entities)``. The corrupted rows are
     concatenated (dim 0) after ``l.batch_in``, and an equal number of
     all-zero targets are concatenated after ``l.batch_out``. Both
     attributes of ``l`` are rebound to the new, larger tensors.

     Args:
         l: learner holding the current batch. ``l.batch_in`` is assumed to
             be an index tensor of shape ``(batch_size, k)`` (the unpack
             below requires exactly two dimensions). ``l.batch_out`` is
             assumed to have ``batch_size`` as its first dimension —
             TODO confirm against callers.

     Note:
         The replacement index is not compared with the value it replaces,
         so a corrupted row can be identical to its source row.
     """
     with torch.no_grad():
         # Random indices are sampled on CPU; the result is moved back to
         # l.batch_in's device before concatenation.
         batch_in = l.batch_in.to("cpu")
         batch_size, k = batch_in.shape
         num_neg_samples = batch_size * self.ratio

         # repeat() allocates a new tensor, so the in-place scatter_ below
         # never modifies l.batch_in itself.
         negative_samples = batch_in.repeat(self.ratio, 1)  # (batch_size * ratio, k)

         # For every row: which column to corrupt, and the index to put there.
         # scatter_ requires src to share negative_samples' dtype and device.
         columns = torch.randint(
             k, (num_neg_samples, 1), device=negative_samples.device)
         replacements = torch.randint(
             self.num_entities, (num_neg_samples, 1),
             dtype=negative_samples.dtype, device=negative_samples.device)
         negative_samples.scatter_(1, columns, replacements)

         # Zero targets with batch_out's trailing shape, dtype and device, so
         # the concatenation neither fails for non-1-D targets nor silently
         # promotes batch_out to float32.
         negative_probs = l.batch_out.new_zeros(
             (num_neg_samples, *l.batch_out.shape[1:]))

         l.batch_in = torch.cat(
             (l.batch_in, negative_samples.to(l.batch_in.device)), 0)
         l.batch_out = torch.cat((l.batch_out, negative_probs), 0)