def test_all_identity(self):
    """Check that selected positions can be left untouched.

    The masker is built with ``mask_rate=1.0`` and both replacement
    proportions set to zero. For ten different PRNG seeds the test asserts
    that the returned inputs and outputs equal the original sequences and
    that the weights equal the non-pad indicator ``xs != vocab.pad``.
    """
    sequences = self._xs
    vocab = self._domain.vocab
    masker_kwargs = dict(
        mask_rate=1.0,
        mask_token_proportion=0.0,
        random_token_proportion=0.0,
    )
    bert_masker = models.BertMasker(self._domain, **masker_kwargs)

    for key in map(jrandom.PRNGKey, range(10)):
        masked_inputs, targets, loss_weights = bert_masker(
            sequences, rng=key, mode=models.Mode.train)

        # Neither MASK nor random replacement is enabled, so nothing changes.
        self.assertAllEqual(sequences, masked_inputs)
        self.assertAllEqual(sequences != vocab.pad, loss_weights)
        self.assertAllEqual(sequences, targets)
def test_all_mask(self):
    """Check replacement of every non-pad position by the MASK token.

    The masker is built with ``mask_rate=1.0`` and
    ``mask_token_proportion=1.0``. For ten different PRNG seeds the test
    asserts that pad positions stay pad, every other position becomes
    ``vocab.mask``, the weights equal the non-pad indicator, and the
    outputs equal the original sequences.
    """
    sequences = self._xs
    vocab = self._domain.vocab
    masker_kwargs = dict(
        mask_rate=1.0,
        mask_token_proportion=1.0,
        random_token_proportion=0.0,
    )
    bert_masker = models.BertMasker(self._domain, **masker_kwargs)

    for key in map(jrandom.PRNGKey, range(10)):
        masked_inputs, targets, loss_weights = bert_masker(
            sequences, rng=key, mode=models.Mode.train)

        was_pad = sequences == vocab.pad
        was_token = sequences != vocab.pad

        # Padding is preserved; everything else is replaced by MASK.
        self.assertAllEqual(was_pad, (masked_inputs == vocab.pad))
        self.assertAllEqual(was_token, (masked_inputs == vocab.mask))
        self.assertAllEqual(was_token, loss_weights)
        self.assertAllEqual(sequences, targets)
def test_all_normal(self):
    """Check replacement of every non-pad position by a "normal" token.

    The masker is built with ``mask_rate=1.0`` and
    ``random_token_proportion=1.0``. For ten different PRNG seeds the test
    asserts that pad positions stay pad, every other input position holds a
    member of ``masker._normal_tokens``, the weights equal the non-pad
    indicator, and the outputs equal the original sequences.
    """
    sequences = self._xs
    vocab = self._domain.vocab
    masker_kwargs = dict(
        mask_rate=1.0,
        mask_token_proportion=0.0,
        random_token_proportion=1.0,
    )
    bert_masker = models.BertMasker(self._domain, **masker_kwargs)

    for key in map(jrandom.PRNGKey, range(10)):
        masked_inputs, targets, loss_weights = bert_masker(
            sequences, rng=key, mode=models.Mode.train)

        # Membership in the masker's own list of normal tokens.
        drawn_from_normal = np.isin(masked_inputs, bert_masker._normal_tokens)
        was_pad = sequences == vocab.pad
        was_token = sequences != vocab.pad

        self.assertAllEqual(was_pad, (masked_inputs == vocab.pad))
        self.assertAllEqual(was_token, drawn_from_normal)
        self.assertAllEqual(was_token, loss_weights)
        self.assertAllEqual(sequences, targets)