def test_get_normalized_matrix(self):
    """Checks `utils.get_normalized_matrix` against a hand-written matrix."""
    vocab = domains.Vocabulary(tokens=['A', 'B', 'C'])
    domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=2)
    # Pairwise frequencies keyed by token, then by token.
    freq_dict = {
        'A': {'A': 5, 'B': 3, 'C': 1},
        'B': {'A': 3, 'B': 5, 'C': 1},
        'C': {'A': 1, 'B': 1, 'C': 1},
    }
    expected_matrix = [
        [1, 0.5, 0],
        [0.5, 1, 0],
        [0, 0, 0],
    ]
    matrix = utils.get_normalized_matrix(domain, freq_dict)
    self.assertAllEqual(matrix, expected_matrix)
def test_soft_accuracy(self):
    """Checks the ratio from `utils.compute_weighted_soft_accuracy`."""
    vocab = domains.Vocabulary(tokens=['A', 'B', 'C'])
    domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=2)
    targets = np.array([[0, 1]])
    logits = np.log([[[0.9, 0.1], [0.6, 0.4]]])
    freq_dict = {
        'A': {'A': 5, 'B': 3, 'C': 1},
        'B': {'A': 3, 'B': 5, 'C': 1},
        'C': {'A': 1, 'B': 1, 'C': 1},
    }
    normalized = utils.get_normalized_matrix(domain, freq_dict)
    accuracy, denominator = utils.compute_weighted_soft_accuracy(
        logits, targets, weights=None, matrix=normalized)
    # The test pins the accuracy-to-denominator ratio at exactly 0.75.
    self.assertEqual(accuracy / denominator, 0.75)
def test_bos_does_not_appear_in_fixed_len_output(self):
    """Checks that no sample from a fixed-length domain contains BOS."""
    fixed_domain = domains.FixedLengthDiscreteDomain(vocab_size=2, length=10)
    model = lm_cls(domain=fixed_domain)
    for drawn in model.sample(10):
        self.assertNotIn(model.bos_token, drawn)
def test_sampling_with_repetition_penalty(self, normalize):
    """Checks that the repetition penalty changes how often tokens repeat.

    Samples are drawn for penalties 1, 100 and 1/100. Relative to penalty 1,
    the fraction of repeated tokens must be more than 0.1 lower for penalty
    100 and more than 0.1 higher for penalty 1/100.

    Args:
      normalize: Forwarded to the model as `repetition_penalty_normalize`.
    """
    length = 4
    vocab = domains.Vocabulary(tokens=['a', 'b', 'c', 'd'], include_bos=True)
    domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=length)
    percent_repeats = []
    for repetition_penalty in [1, 100, 1 / 100]:
        lm = lm_cls(
            domain=domain,
            repetition_penalty=repetition_penalty,
            repetition_penalty_normalize=normalize)
        batch_size = 100
        prompt_token = domain.vocab.tokens.index('a')
        # Every prompt row is [BOS, 'a'].
        ones_column = jnp.ones((batch_size, 1)).astype(jnp.int32)
        prompt = jnp.concatenate(
            [ones_column * lm.bos_token, ones_column * prompt_token], axis=1)
        samples_str = domain.decode(lm.sample_with_prompt(prompt))
        logging.info('samples: %s', str(samples_str))
        # A position counts as a repeat when its token already occurred
        # earlier in the same sample.
        num_repeats = 0
        for sample in samples_str:
            for i in np.arange(1, length):
                if sample[:i].count(sample[i]) > 0:
                    num_repeats += 1
        percent_repeats.append(num_repeats / (batch_size * (length - 1)))
    logging.info('percent_repeats: %s', str(percent_repeats))
    baseline, penalized, encouraged = percent_repeats
    self.assertGreater(baseline - penalized, 0.1)
    self.assertGreater(encouraged - baseline, 0.1)
def _test_domain():
    """Returns a length-3 domain over 'a', 'b', 'c' with BOS, mask and pad."""
    return domains.FixedLengthDiscreteDomain(
        vocab=domains.Vocabulary(
            tokens=['a', 'b', 'c'],
            include_bos=True,
            include_mask=True,
            include_pad=True,
        ),
        length=3,
    )
def test_flaxlm_evaluation(self):
    """Smoke test: `evaluation.evaluate` runs on a model from `lm_cls`."""
    vocab = domains.Vocabulary(tokens=range(2), include_bos=True)
    domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=1)
    eval_ds = tf.data.Dataset.from_tensor_slices(
        (np.array([[0, 1], [1, 0]]),))
    # No assertion on the result; the call only has to complete.
    evaluation.evaluate(lm_cls(domain=domain), eval_ds)
def test_empirical_baseline_construction(self):
    """Checks the empirical distribution stored by `EmpiricalBaseline`."""
    vocab = domains.Vocabulary(tokens=range(3), include_bos=True)
    domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=2)
    train_ds = tf.data.Dataset.from_tensor_slices(
        (np.array([[0, 1], [1, 0]]),))
    baseline = evaluation.EmpiricalBaseline(domain, train_ds, alpha=0)
    # Tokens 0 and 1 each make up half the training data; token 2 is absent.
    self.assertAllEqual(baseline._empirical_dist, [0.5, 0.5, 0])
def test_count_params(self):
    """Checks the parameter count and that the param helpers run."""
    model = lm_cls(
        domain=domains.FixedLengthDiscreteDomain(length=4, vocab_size=2))
    self.assertEqual(13059, utils.param_count(model))
    # Check these methods run.
    utils.param_pprint(model)
    self.assertIsInstance(utils.param_reduce(model, log=True), dict)
def setUp(self):
    """Builds the domain and model fixtures, then runs the parent setUp."""
    make_model = functools.partial(
        models.FlaxModel,
        pmap=False,
        with_bos=True,
        output_head=('logits', 'regression'),
        **lm_cfg,
    )
    self._domain = domains.FixedLengthDiscreteDomain(length=6, vocab_size=4)
    self.lm = make_model(domain=self._domain)
    super().setUp()
def test_empirical_baseline_evaluation(self):
    """Checks the metrics from evaluating an `EmpiricalBaseline`."""
    vocab = domains.Vocabulary(tokens=range(2), include_bos=True)
    domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=1)
    # Training and evaluation use the same two rows.
    rows = [[0, 1], [1, 0]]
    train_ds = tf.data.Dataset.from_tensor_slices((np.array(rows),))
    eval_ds = tf.data.Dataset.from_tensor_slices((np.array(rows),))
    baseline = evaluation.EmpiricalBaseline(domain, train_ds)
    metrics = evaluation.evaluate(baseline, eval_ds)
    self.assertAllEqual(np.asarray(metrics['accuracy']), 0.5)
    self.assertAllClose(np.asarray(metrics['perplexity']), 2)
    self.assertAllClose(np.asarray(metrics['loss']), 0.69, atol=0.1)
def test_output_head(self, output_head, multiple_heads):
    """Checks the structure of `models.predict_step` output per head spec.

    Args:
      output_head: Head specification forwarded to `models.predict_step`.
      multiple_heads: Whether a dict with one entry per head is expected.
    """
    domain = domains.FixedLengthDiscreteDomain(vocab_size=2, length=2)
    inputs = domain.sample_uniformly(8)
    lm = lm_cls(domain=domain, pmap=False)
    outputs = models.predict_step(
        lm.optimizer.target,
        inputs,
        preprocess_fn=lm.preprocess,
        output_head=output_head,
    )
    if not multiple_heads:
        # We should have gotten a single output, the logits.
        expected_shape = (inputs.shape[0], inputs.shape[1], lm.vocab_size)
        self.assertEqual(outputs.shape, expected_shape)
        return
    self.assertIsInstance(outputs, dict)
    self.assertLen(outputs, len(output_head))
def test_positional_encodings(self, positional_encoding_module):
    """Smoke test: sampling runs with the given positional encoding module.

    Args:
      positional_encoding_module: Forwarded to the model constructor.
    """
    tiny_domain = domains.FixedLengthDiscreteDomain(vocab_size=2, length=2)
    model = lm_cls(
        domain=tiny_domain,
        positional_encoding_module=positional_encoding_module,
    )
    # No assertion on the sample; drawing one only has to complete.
    model.sample(1)
def _test_domain():
    """Returns a fixed-length domain with length 3 and vocab size 4."""
    domain = domains.FixedLengthDiscreteDomain(length=3, vocab_size=4)
    return domain