예제 #1
0
 def test_all_identity(self):
   """Test no-mask case (maintaining original values).

   Builds a BertMasker with mask_rate=1.0 and both replacement proportions
   (MASK and random) at 0.0. Across ten PRNG seeds, checks that:
     * the masked inputs equal the original tokens,
     * the weights equal the non-pad indicator of the original tokens,
     * the targets (outputs) equal the original tokens.
   """
   tokens = self._xs
   vocab = self._domain.vocab
   # Identical for every seed, so compute it once up front.
   non_pad = tokens != vocab.pad
   identity_masker = models.BertMasker(
       self._domain,
       mask_rate=1.0,
       mask_token_proportion=0.0,
       random_token_proportion=0.0)
   for seed in range(10):
     inputs, outputs, weights = identity_masker(
         tokens, rng=jrandom.PRNGKey(seed), mode=models.Mode.train)
     self.assertAllEqual(tokens, inputs)
     self.assertAllEqual(non_pad, weights)
     self.assertAllEqual(tokens, outputs)
예제 #2
0
 def test_all_mask(self):
   """Test masking with MASK values.

   Builds a BertMasker with mask_rate=1.0, mask_token_proportion=1.0 and
   random_token_proportion=0.0. Across ten PRNG seeds, checks that:
     * pad positions in the inputs coincide with pad positions originally,
     * every originally non-pad position now holds the MASK token,
     * the weights equal the non-pad indicator of the original tokens,
     * the targets (outputs) equal the original tokens.
   """
   tokens = self._xs
   vocab = self._domain.vocab
   # Seed-independent indicators of the original sequence.
   is_pad = tokens == vocab.pad
   non_pad = tokens != vocab.pad
   mask_masker = models.BertMasker(
       self._domain,
       mask_rate=1.0,
       mask_token_proportion=1.0,
       random_token_proportion=0.0)
   for seed in range(10):
     inputs, outputs, weights = mask_masker(
         tokens, rng=jrandom.PRNGKey(seed), mode=models.Mode.train)
     self.assertAllEqual(is_pad, inputs == vocab.pad)
     self.assertAllEqual(non_pad, inputs == vocab.mask)
     self.assertAllEqual(non_pad, weights)
     self.assertAllEqual(tokens, outputs)
예제 #3
0
 def test_all_normal(self):
   """Test masking with random values.

   Builds a BertMasker with mask_rate=1.0, mask_token_proportion=0.0 and
   random_token_proportion=1.0. Across ten PRNG seeds, checks that:
     * pad positions in the inputs coincide with pad positions originally,
     * every originally non-pad position holds a value from the masker's
       `_normal_tokens` set (and, by the exact-equality check, no other
       position does),
     * the weights equal the non-pad indicator of the original tokens,
     * the targets (outputs) equal the original tokens.
   """
   tokens = self._xs
   vocab = self._domain.vocab
   # Seed-independent indicators of the original sequence.
   is_pad = tokens == vocab.pad
   non_pad = tokens != vocab.pad
   random_masker = models.BertMasker(
       self._domain,
       mask_rate=1.0,
       mask_token_proportion=0.0,
       random_token_proportion=1.0)
   for seed in range(10):
     inputs, outputs, weights = random_masker(
         tokens, rng=jrandom.PRNGKey(seed), mode=models.Mode.train)
     # Relies on the private `_normal_tokens` attribute as the reference set
     # of tokens a random replacement may draw from.
     drawn_from_normal = np.isin(inputs, random_masker._normal_tokens)
     self.assertAllEqual(is_pad, inputs == vocab.pad)
     self.assertAllEqual(non_pad, drawn_from_normal)
     self.assertAllEqual(non_pad, weights)
     self.assertAllEqual(tokens, outputs)