Example #1
0
 def test_forward_one_batch(self):
     """Score one chunk with DotComparator and check values and gradients.

     The chunk holds two positive pairs of 3-dimensional embeddings and four
     negatives on each side. Each negative side contains one all-zero row, so
     the matching column of the expected negative scores is 0. Finally the sum
     of all scores is backpropagated, and every input must have received some
     non-zero gradient.
     """
     dot = DotComparator()

     lhs_pos = torch.tensor(
         [[[0.8931, 0.2241, 0.4241], [0.6557, 0.2492, 0.4157]]],
         requires_grad=True,
     )
     rhs_pos = torch.tensor(
         [[[0.9220, 0.2892, 0.7408], [0.1476, 0.6079, 0.1835]]],
         requires_grad=True,
     )
     lhs_neg = torch.tensor(
         [[
             [0.3836, 0.7648, 0.0965],
             [0.8929, 0.8947, 0.4877],
             [0.0000, 0.0000, 0.0000],  # zero row -> column 2 scores are 0
             [0.7967, 0.6736, 0.2966],
         ]],
         requires_grad=True,
     )
     rhs_neg = torch.tensor(
         [[
             [0.6116, 0.6010, 0.9500],
             [0.0000, 0.0000, 0.0000],  # zero row -> column 1 scores are 0
             [0.2360, 0.5923, 0.7536],
             [0.1290, 0.3088, 0.2731],
         ]],
         requires_grad=True,
     )
     inputs = (lhs_pos, rhs_pos, lhs_neg, rhs_neg)

     # Prepare every input, in argument order, then score them in one call.
     pos_scores, lhs_neg_scores, rhs_neg_scores = dot(
         *[dot.prepare(t) for t in inputs]
     )

     expected_pos = torch.tensor([[1.2024, 0.3246]])
     expected_lhs_neg = torch.tensor(
         [[[0.6463, 1.4433, 0.0000, 1.1491], [0.5392, 0.7652, 0.0000, 0.5815]]]
     )
     expected_rhs_neg = torch.tensor(
         [[[1.0838, 0.0000, 0.6631, 0.3002], [0.9457, 0.0000, 0.6156, 0.2751]]]
     )
     self.assertTensorEqual(pos_scores, expected_pos)
     self.assertTensorEqual(lhs_neg_scores, expected_lhs_neg)
     self.assertTensorEqual(rhs_neg_scores, expected_rhs_neg)

     total = pos_scores.sum() + lhs_neg_scores.sum() + rhs_neg_scores.sum()
     total.backward()
     # Every leaf tensor must have a gradient that is not entirely zero.
     for leaf in inputs:
         self.assertTrue((leaf.grad != 0).any())
# Print the shapes of two embedding arrays. Both names are defined outside
# this section — presumably arrays loaded from a PBG HDF5 checkpoint earlier
# in the script (TODO confirm).
print(embedding_user_0.shape)
print(embedding_all.shape)

from torchbiggraph.model import DotComparator
import torch

# Score user "0" against users "7" and "135" using the "model/demo" checkpoint.
# `dictionary` and `h5py` come from earlier in the script (not visible here).
# A name's position in dictionary["entities"]["user_id"] is used as its row in
# the HDF5 "embeddings" table; .index() raises ValueError if a name is missing.
src_entity_offset = dictionary["entities"]["user_id"].index("0")
dest_1_entity_offset = dictionary["entities"]["user_id"].index("7")
dest_2_entity_offset = dictionary["entities"]["user_id"].index("135")

with h5py.File("model/demo/embeddings_user_id_0.v10.h5", "r") as hf:
    src_embedding = hf["embeddings"][src_entity_offset, :]
    dest_1_embedding = hf["embeddings"][dest_1_entity_offset, :]
    dest_2_embedding = hf["embeddings"][dest_2_entity_offset, :]
    # Whole embedding table read into memory; not used within this snippet.
    dest_embeddings = hf["embeddings"][...]

comparator = DotComparator()

# Shape each vector as (1, 1, dim), inferring dim from the data rather than
# hard-coding it, so checkpoints of any embedding dimension work.
src_tensor = torch.tensor(src_embedding.reshape(1, 1, -1))
src_pos = comparator.prepare(src_tensor)
# Empty (1, 0, dim) negatives: only the positive score is wanted. Matching the
# positives' dtype avoids a dtype mismatch for non-float32 checkpoints.
no_negatives = torch.empty(1, 0, src_tensor.shape[-1], dtype=src_tensor.dtype)
# NOTE(review): embeddings are compared directly with no relation operator
# applied; if the trained relation type uses one, apply it first — confirm.

scores_1, _, _ = comparator(
    src_pos,
    comparator.prepare(torch.tensor(dest_1_embedding.reshape(1, 1, -1))),
    no_negatives,  # Left-hand side negatives, not needed
    no_negatives,  # Right-hand side negatives, not needed
)

scores_2, _, _ = comparator(
    src_pos,
    comparator.prepare(torch.tensor(dest_2_embedding.reshape(1, 1, -1))),
    no_negatives,  # Left-hand side negatives, not needed
    no_negatives,  # Right-hand side negatives, not needed
)
print("Now let's do some simple things within torch:")

from torchbiggraph.model import DotComparator
import torch

# Score user "0" against users "7" and "1" using the "model/example_2"
# checkpoint. `dictionary` and `h5py` come from earlier in the script (not
# visible here). A name's position in dictionary["entities"]["user_id"] is used
# as its row in the HDF5 "embeddings" table; .index() raises ValueError if a
# name is missing.
src_entity_offset = dictionary["entities"]["user_id"].index("0")
dest_1_entity_offset = dictionary["entities"]["user_id"].index("7")
dest_2_entity_offset = dictionary["entities"]["user_id"].index("1")

with h5py.File("model/example_2/embeddings_user_id_0.v10.h5", "r") as hf:
    src_embedding = hf["embeddings"][src_entity_offset, :]
    dest_1_embedding = hf["embeddings"][dest_1_entity_offset, :]
    dest_2_embedding = hf["embeddings"][dest_2_entity_offset, :]
    # Whole embedding table read into memory; not used within this snippet.
    dest_embeddings = hf["embeddings"][...]

comparator = DotComparator()

# Shape each vector as (1, 1, dim), inferring dim from the data rather than
# hard-coding it, so checkpoints of any embedding dimension work.
src_tensor = torch.tensor(src_embedding.reshape(1, 1, -1))
src_pos = comparator.prepare(src_tensor)
# Empty (1, 0, dim) negatives: only the positive score is wanted. Matching the
# positives' dtype avoids a dtype mismatch for non-float32 checkpoints.
no_negatives = torch.empty(1, 0, src_tensor.shape[-1], dtype=src_tensor.dtype)
# NOTE(review): embeddings are compared directly with no relation operator
# applied; if the trained relation type uses one, apply it first — confirm.

scores_1, _, _ = comparator(
    src_pos,
    comparator.prepare(torch.tensor(dest_1_embedding.reshape(1, 1, -1))),
    no_negatives,  # Left-hand side negatives, not needed
    no_negatives,  # Right-hand side negatives, not needed
)

scores_2, _, _ = comparator(
    src_pos,
    comparator.prepare(torch.tensor(dest_2_embedding.reshape(1, 1, -1))),
    no_negatives,  # Left-hand side negatives, not needed
    no_negatives,  # Right-hand side negatives, not needed
)