def test_sample_negatives(self):
        """Check the output of `Wav2Vec2ForPreTraining._sample_negatives` without a mask.

        Builds features in which every vector along the hidden dimension holds a
        single repeated value (the time-step index), then verifies that:

        * the result has shape `(num_negatives, batch_size, sequence_length, hidden_size)`,
        * no sampled negative equals the feature vector at the same position,
        * whole vectors are sampled rather than individual values.
        """
        batch_size = 2
        sequence_length = 10
        hidden_size = 4
        num_negatives = 3

        # each vector consists of a single repeated value: row `t` is filled with `t`
        features = (
            torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size
        ).view(sequence_length, hidden_size)
        features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()

        negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)

        self.assertEqual(negatives.shape, (num_negatives, batch_size, sequence_length, hidden_size))

        # make sure no negatively sampled vector is actually a positive one
        for negative in negatives:
            self.assertEqual(((negative - features) == 0).sum().item(), 0)

        # make sure that full vectors are sampled and not values of vectors
        # => `unique()` yields a single value for the `hidden_size` dim.
        # NOTE: this must be `assertEqual`; `assertTrue(a, b)` treats `b` as the
        # failure message and passes for any truthy `a`.
        self.assertEqual(
            negatives.unique(dim=-1).shape,
            (num_negatives, batch_size, sequence_length, 1),
        )
    def test_sample_negatives_with_attn_mask(self):
        """Check the output of `Wav2Vec2ForPreTraining._sample_negatives` with an attention mask.

        The second half of the last batch element is masked out and its feature
        vectors are overwritten with -100. The test then verifies that:

        * no masked (-100) vector is ever sampled as a negative,
        * the result has shape `(num_negatives, batch_size, sequence_length, hidden_size)`,
        * no sampled negative equals the feature vector at the same position,
        * whole vectors are sampled rather than individual values.
        """
        batch_size = 2
        sequence_length = 10
        hidden_size = 4
        num_negatives = 3

        # second half of last input tensor is padded
        attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
        attention_mask[-1, sequence_length // 2 :] = 0

        # each vector consists of a single repeated value: row `t` is filled with `t`
        features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
            sequence_length, hidden_size
        )
        features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()

        # replace masked feature vectors with -100 to test that those are not sampled
        features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100)

        negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask)

        # all unmasked feature values are >= 0, so a negative value means a masked vector was sampled
        self.assertTrue((negatives >= 0).all().item())

        self.assertEqual(negatives.shape, (num_negatives, batch_size, sequence_length, hidden_size))

        # make sure no negatively sampled vector is actually a positive one
        for negative in negatives:
            self.assertEqual(((negative - features) == 0).sum().item(), 0)

        # make sure that full vectors are sampled and not values of vectors
        # => `unique()` yields a single value for the `hidden_size` dim.
        # NOTE: this must be `assertEqual`; `assertTrue(a, b)` treats `b` as the
        # failure message and passes for any truthy `a`.
        self.assertEqual(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))