예제 #1
0
 def setUp(self) -> None:
     """Build the Nations training factory and a GraphSampler with batch_size 20 over its triples."""
     factory = Nations().training
     batch_size = 20
     self.triples_factory = factory
     self.batch_size = batch_size
     self.num_epochs = 10
     # the sampler is given the raw mapped triples of the training factory
     self.graph_sampler = GraphSampler(
         mapped_triples=factory.mapped_triples,
         batch_size=batch_size,
     )
예제 #2
0
 def setUp(self) -> None:
     """Build the Nations training factory and a GraphSampler configured with num_samples 20."""
     factory = Nations().training
     sample_count = 20
     self.triples_factory = factory
     self.num_samples = sample_count
     self.num_epochs = 10
     # this API variant takes the whole triples factory rather than raw triples
     self.graph_sampler = GraphSampler(
         triples_factory=factory,
         num_samples=sample_count,
     )
예제 #3
0
class GraphSamplerTest(unittest.TestCase):
    """Test the GraphSampler."""

    def setUp(self) -> None:
        """Set up the test case with a triples factory."""
        self.triples_factory = Nations().training
        self.batch_size = 20
        self.num_epochs = 10
        self.graph_sampler = GraphSampler(
            mapped_triples=self.triples_factory.mapped_triples,
            batch_size=self.batch_size,
        )

    @staticmethod
    def _count_connected_components(triples: torch.Tensor) -> int:
        """Return how many connected components the given triples span.

        Each row is unpacked as ``(head, relation, tail)``; the relation is
        ignored and every triple is treated as an undirected edge between its
        head and tail entity. Only entities occurring in *triples* are counted,
        so a tensor with no rows yields ``0``.

        Uses a union-find (disjoint-set) forest keyed by entity id.
        """
        parent = {}

        def find(node: int) -> int:
            # unseen entities start out as their own singleton component
            parent.setdefault(node, node)
            while parent[node] != node:
                # path halving: jump to the grandparent to keep trees shallow
                parent[node] = parent[parent[node]]
                node = parent[node]
            return node

        for head, _, tail in triples.tolist():
            head_root, tail_root = find(int(head)), find(int(tail))
            if head_root != tail_root:
                parent[head_root] = tail_root

        # each root (an entity that is its own parent) represents one component
        return sum(1 for node, root in parent.items() if node == root)

    def test_sample(self) -> None:
        """Test that a sampled batch has the expected size and is connected."""
        batch = torch.as_tensor(list(self.graph_sampler.sample_batch()))

        # one index into mapped_triples per batch slot
        self.assertEqual(tuple(batch.shape), (self.batch_size,))

        triples_batch = self.triples_factory.mapped_triples[batch]

        # the sampled triples must form a single (undirected) connected subgraph
        self.assertEqual(
            1,
            self._count_connected_components(triples_batch),
            msg="sampled triples do not form a single connected component",
        )