예제 #1
0
 def test_compute_logprob(self):
   """Checks compute_logprob against the loss reported by evaluate_batch.

   Samples 16 sequences (seed 0) from the test domain. The test then asserts
   that the first entry of the 'loss' metric from `evaluate_batch` equals the
   negated sum of the values returned by `models.compute_logprob` on the same
   sequences.
   """
   # NOTE(review): `lm_cls` is not defined in this block; presumably it is
   # supplied by the enclosing test class or a parameterization — confirm.
   test_domain = _test_domain()
   model = lm_cls(domain=test_domain)
   batch = test_domain.sample_uniformly(16, seed=0)

   # Keep the original call order: evaluate_batch runs before compute_logprob.
   batch_metrics = model.evaluate_batch(batch)
   logprobs = models.compute_logprob(batch, model)

   # evaluate_batch() yields one set of metrics per device. The tests run on a
   # single device, so entry [0] holds the only loss value.
   reported_loss = batch_metrics['loss'][0]
   expected_loss = -jnp.sum(logprobs)
   self.assertAllClose(reported_loss, expected_loss)
예제 #2
0
def _compute_logprob(inputs, model, weights):
    """Score `inputs` with `model` after loading `weights` into it.

    Args:
      inputs: Data forwarded unchanged to `models.compute_logprob`.
      model: Object exposing `set_weights`. It is modified in place: the new
        weights remain installed after this call returns.
      weights: Value passed straight to `model.set_weights`.

    Returns:
      Whatever `models.compute_logprob` returns for `inputs`, called with
      `mask_token=None`.
    """
    # Install the weights first so the scoring below uses them.
    model.set_weights(weights)
    # mask_token=None is passed explicitly. NOTE(review): presumably this
    # disables token masking inside compute_logprob — confirm against models.
    scores = models.compute_logprob(inputs, model, mask_token=None)
    return scores