예제 #1
0
class _NegativeSamplingTestCase:
    """Shared fixture and generic checks for negative sampler tests.

    Subclasses provide ``negative_sampling_cls``. ``setUp`` then builds two samplers on the
    Nations training triples (one with default settings, one with ``num_negs_per_pos`` set)
    and draws a seeded batch of positive triples for the tests to corrupt.
    """

    #: Number of positive triples drawn for the batch
    batch_size: int
    #: Seed for the random state used to pick the positive batch
    seed: int
    #: Triples factory the samplers are built on
    triples_factory: TriplesFactory
    #: sLCWA instances created from the triples factory
    slcwa_instances: SLCWAInstances
    #: The negative sampler class under test
    negative_sampling_cls: ClassVar[Type[NegativeSampler]]
    #: Sampler instance with default settings, created in setUp
    negative_sampler: NegativeSampler
    #: The batch of positive triples handed to the samplers
    positive_batch: torch.LongTensor

    def setUp(self) -> None:
        """Create the triples factory, both samplers, and a reproducible positive batch."""
        self.batch_size, self.seed, self.num_negs_per_pos = 16, 42, 10

        self.triples_factory = Nations().training
        self.slcwa_instances = self.triples_factory.create_slcwa_instances()

        # one sampler with default settings, one producing several negatives per positive
        sampler_cls = self.negative_sampling_cls
        self.negative_sampler = sampler_cls(triples_factory=self.triples_factory)
        self.scaling_negative_sampler = sampler_cls(
            triples_factory=self.triples_factory,
            num_negs_per_pos=self.num_negs_per_pos,
        )

        # a dedicated RandomState keeps the selected rows independent of global RNG state
        rng = numpy.random.RandomState(seed=self.seed)
        row_ids = rng.randint(
            low=0,
            high=self.slcwa_instances.num_instances,
            size=(self.batch_size,),
        )
        self.positive_batch = self.slcwa_instances.mapped_triples[row_ids]

    def test_sample(self) -> None:
        """Check shape, index bounds, and corruption of the sampled negatives."""
        positives = self.positive_batch

        # negatives from the default sampler
        negatives = self.negative_sampler.sample(positive_batch=positives)

        # the default sampler is expected to return a batch shaped like the positives
        assert negatives.shape == positives.shape

        # bounds per column: heads (entities), relations, tails (entities)
        num_entities = self.triples_factory.num_entities
        num_relations = self.triples_factory.num_relations
        for column, upper in enumerate((num_entities, num_relations, num_entities)):
            assert _array_check_bounds(negatives[:, column], low=0, high=upper)

        # every negative has to differ from its positive in at least one position
        differs_somewhere = (negatives != positives).any(dim=1)
        assert differs_somewhere.all()

        # negatives from the sampler configured with num_negs_per_pos
        scaled_negatives = self.scaling_negative_sampler.sample(positive_batch=positives)

        expected_rows = positives.shape[0] * self.num_negs_per_pos
        assert scaled_negatives.shape[0] == expected_rows
        assert scaled_negatives.shape[1] == positives.shape[1]
예제 #2
0
파일: cases.py 프로젝트: tgebhart/pykeen
class NegativeSamplerGenericTestCase(
        unittest_templates.GenericTestCase[NegativeSampler]):
    """A test case for quickly defining common tests for samplers.

    ``self.instance``, ``self.cls`` and ``self.instance_kwargs`` are expected to be provided
    by the generic base class; this class contributes the triples factory to the
    instantiation keyword arguments and a seeded batch of positive triples.
    """

    #: The batch size
    batch_size: int = 16
    #: The random seed
    seed: int = 42
    #: The triples factory
    triples_factory: TriplesFactory
    #: The instances
    training_instances: Instances
    #: A positive batch
    positive_batch: torch.LongTensor
    #: Kwargs
    kwargs = {
        'num_negs_per_pos': 10,
    }

    def pre_setup_hook(self) -> None:
        """Set up the test case with a triples factory, training instances, and a default positive batch."""
        self.triples_factory = Nations().training
        self.training_instances = self.triples_factory.create_slcwa_instances()
        # a dedicated RandomState makes the selected batch reproducible for a given seed
        random_state = numpy.random.RandomState(seed=self.seed)
        batch_indices = random_state.randint(low=0,
                                             high=len(self.training_instances),
                                             size=(self.batch_size, ))
        self.positive_batch = self.training_instances.mapped_triples[
            batch_indices]

    def _pre_instantiation_hook(
        self, kwargs: MutableMapping[str, Any]
    ) -> MutableMapping[str, Any]:  # noqa: D102
        kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
        # every sampler is built on the triples factory created in pre_setup_hook
        kwargs['triples_factory'] = self.triples_factory
        return kwargs

    def check_sample(self, instance: NegativeSampler) -> None:
        """Test generating a negative sample.

        Verifies the filter mask (present with shape ``(batch_size, num_negs_per_pos)`` and
        boolean dtype exactly when the sampler has a filterer), the shape of the negative
        batch, the index bounds of all three columns, and, for filtered samplers, that no
        retained negative equals its positive triple.

        :param instance:
            The negative sampler to check.
        """
        # Generate negative sample
        negative_batch, batch_filter = instance.sample(
            positive_batch=self.positive_batch)

        # check filter shape if necessary
        if instance.filterer is not None:
            assert batch_filter is not None
            assert batch_filter.shape == (self.batch_size,
                                          instance.num_negs_per_pos)
            assert batch_filter.dtype == torch.bool
        else:
            assert batch_filter is None

        # check shape
        assert negative_batch.shape == (self.positive_batch.shape[0],
                                        instance.num_negs_per_pos, 3)

        # check bounds: heads
        assert _array_check_bounds(negative_batch[..., 0],
                                   low=0,
                                   high=self.triples_factory.num_entities)

        # check bounds: relations
        assert _array_check_bounds(negative_batch[..., 1],
                                   low=0,
                                   high=self.triples_factory.num_relations)

        # check bounds: tails
        assert _array_check_bounds(negative_batch[..., 2],
                                   low=0,
                                   high=self.triples_factory.num_entities)

        if instance.filterer is not None:
            # align positives with negatives: (batch_size, num_negs_per_pos, 3), then keep
            # only the entries selected by the filter mask
            positive_batch = self.positive_batch.unsqueeze(dim=1).repeat(
                1, instance.num_negs_per_pos, 1)
            positive_batch = positive_batch[batch_filter]
            negative_batch = negative_batch[batch_filter]

            # test that the negative triple is not the original positive triple
            assert (negative_batch != positive_batch).any(dim=-1).all()

    def test_sample_no_filter(self) -> None:
        """Test generating a negative sample."""
        self.check_sample(self.instance)

    def test_sample_filtered(self) -> None:
        """Test generating a negative sample with filtering."""
        filterer = PythonSetFilterer(
            mapped_triples=self.triples_factory.mapped_triples)
        instance = self.cls(**self.instance_kwargs, filterer=filterer)
        self.check_sample(instance)

    def _update_positive_batch(self, positive_batch, batch_filter):
        """Bring a positive batch into a shape comparable with a sampled negative batch.

        :param positive_batch: shape: (batch_size, 3)
            The positive triples.
        :param batch_filter: shape: (batch_size, num_negs_per_pos), or None
            The boolean filter mask returned by the sampler, if any.

        :return:
            Without a filter, the positives with an added broadcastable dimension, shape
            (batch_size, 1, 3). With a filter, the positives repeated once per negative and
            reduced to the entries selected by the mask, shape (num_selected, 3).
        """
        # shape: (batch_size, 1, 3)
        positive_batch = positive_batch.unsqueeze(dim=1)

        if batch_filter is None:
            return positive_batch

        # Boolean mask indexing does not broadcast: the mask's dimensions must match the
        # leading dimensions of the tensor. Repeat the positives along the negatives
        # dimension first (as check_sample does), otherwise any mask with more than one
        # column raises an IndexError.
        # shape: (batch_size, num_negs_per_pos, 3)
        positive_batch = positive_batch.repeat(1, batch_filter.shape[1], 1)
        return positive_batch[batch_filter]

    def test_small_batch(self) -> None:
        """Test on a small batch."""
        # only checks that sampling a single-triple batch does not raise
        self.instance.sample(positive_batch=self.positive_batch[:1])