def test_compute_logprob(self):
    """Checks that evaluate_batch()'s loss matches compute_logprob().

    A model built on the test domain scores one uniformly sampled batch of
    16 sequences in two ways:
    - through `evaluate_batch`;
    - through `models.compute_logprob`.

    The reported loss must equal the negated sum of the log-likelihoods.
    """
    test_domain = _test_domain()
    language_model = lm_cls(domain=test_domain)
    batch = test_domain.sample_uniformly(16, seed=0)

    # Same call order as before: evaluate first, then compute log-likelihoods.
    batch_metrics = language_model.evaluate_batch(batch)
    logprobs = models.compute_logprob(batch, language_model)
    expected_loss = -jnp.sum(logprobs)

    # evaluate_batch() returns an array of per-batch metrics for each device.
    # The tests run on only one device, so entry [0] is the one to compare.
    reported_loss = batch_metrics['loss'][0]
    self.assertAllClose(reported_loss, expected_loss)
def _compute_logprob(inputs, model, weights):
    """Installs `weights` on `model`, then scores `inputs` with it.

    Args:
      inputs: Batch forwarded unchanged to `models.compute_logprob`.
      model: Object whose `set_weights` is called before scoring. It is
        presumably left holding `weights` after this returns (TODO confirm
        the `set_weights` semantics).
      weights: Value passed straight to `model.set_weights`.

    Returns:
      Whatever `models.compute_logprob(inputs, model, mask_token=None)`
      returns.
    """
    model.set_weights(weights)
    # NOTE(review): mask_token=None presumably means no token is treated as
    # a mask during scoring; verify against models.compute_logprob.
    logprob = models.compute_logprob(inputs, model, mask_token=None)
    return logprob