def test_dgmc_on_multiple_graphs():
    """DGMC matches a batch of two graphs identically with and without ``model.k``.

    ``data`` is batched twice and matched against that same batch. The first
    pass uses the freshly constructed model. The second pass re-calls
    ``set_seed()`` and sets ``model.k = data.num_nodes`` (a sparse "dense"
    variant). Its outputs are converted with ``to_dense()`` and must be
    ``allclose`` to the first pass.
    """
    set_seed()
    model = DGMC(psi_1, psi_2, num_steps=1)

    batch = Batch.from_data_list([data, data])
    # Source and target sides receive exactly the same four arguments.
    graph_args = (batch.x, batch.edge_index, None, batch.batch)
    expected_size = (batch.num_nodes, data.num_nodes)

    set_seed()
    ref_0, ref_L = model(*graph_args, *graph_args)
    for S in (ref_0, ref_L):
        assert S.size() == expected_size

    # set_seed() again, presumably so the second pass starts from the same
    # RNG state as the first one.
    set_seed()
    model.k = data.num_nodes  # Test a sparse "dense" variant.
    sparse_0, sparse_L = model(*graph_args, *graph_args)

    for dense, sparse in ((ref_0, sparse_0), (ref_L, sparse_L)):
        assert torch.allclose(dense, sparse.to_dense())
def test_dgmc_on_single_graphs():
    """DGMC yields the same outputs, loss and metrics with and without ``model.k``.

    ``data`` is matched against itself, with node ``i`` paired to node ``i``
    as the ground truth. The first pass uses the freshly built model. The
    second pass re-seeds, sets ``model.k = data.num_nodes`` and also hands
    the ground truth to the forward call. The two passes must agree on the
    correspondence matrices, the loss, the accuracy and hits@{1, 10, all}.
    Hits over all nodes must be exactly 1.0.
    """

    def identity_matching():
        # [2, num_nodes] index tensor whose two rows are both arange(num_nodes).
        idx = torch.arange(data.num_nodes)
        return torch.stack([idx, idx], dim=0)

    def score(net, S, gt):
        # Loss is back-propagated first, then accuracy and hits@{1, 10, all}
        # are computed, in that order.
        loss = net.loss(S, gt)
        loss.backward()
        acc = net.acc(S, gt)
        at_1 = net.hits_at_k(1, S, gt)
        at_10 = net.hits_at_k(10, S, gt)
        at_all = net.hits_at_k(data.num_nodes, S, gt)
        return loss, acc, at_1, at_10, at_all

    set_seed()
    model = DGMC(psi_1, psi_2, num_steps=1)
    x, e = data.x, data.edge_index
    y = identity_matching()

    set_seed()
    dense_0, dense_L = model(x, e, None, None, x, e, None, None)
    dense_loss, dense_acc, dense_h1, dense_h10, dense_hall = score(
        model, dense_0, y)

    set_seed()
    model.k = data.num_nodes  # Test a sparse "dense" variant.
    # A fresh ground-truth tensor is built rather than reusing the one the
    # metric calls above received.
    y = identity_matching()
    sparse_0, sparse_L = model(x, e, None, None, x, e, None, None, y)
    sparse_loss, sparse_acc, sparse_h1, sparse_h10, sparse_hall = score(
        model, sparse_0, y)

    square = (data.num_nodes, data.num_nodes)
    assert dense_0.size() == square
    assert dense_L.size() == square
    assert torch.allclose(dense_0, sparse_0.to_dense())
    assert torch.allclose(dense_L, sparse_L.to_dense())
    assert torch.allclose(dense_loss, sparse_loss)
    assert dense_acc == sparse_acc == dense_h1 == sparse_h1
    assert dense_h1 <= dense_h10 == sparse_h10 <= dense_hall
    assert dense_hall == sparse_hall == 1.0