def test_get_normalized_matrix(self):
    """Tests that the normalized matrix is computed correctly.

    `freq_dict` is a symmetric token-by-token count table over the
    vocabulary {'A', 'B', 'C'}; the expected result is a 3x3 matrix whose
    rows and columns follow the vocabulary order.
    """
    domain = domains.FixedLengthDiscreteDomain(
        vocab=domains.Vocabulary(tokens=['A', 'B', 'C']), length=2)
    freq_dict = {
        'A': {'A': 5, 'B': 3, 'C': 1},
        'B': {'A': 3, 'B': 5, 'C': 1},
        'C': {'A': 1, 'B': 1, 'C': 1},
    }
    matrix = utils.get_normalized_matrix(domain, freq_dict)
    expected_matrix = [
        [1, 0.5, 0],
        [0.5, 1, 0],
        [0, 0, 0],
    ]
    # The expected entries are non-integer floats (e.g. 0.5), so compare
    # with a tolerance rather than requiring bit-exact equality.
    self.assertAllClose(matrix, expected_matrix)
def test_soft_accuracy(self):
    """Tests that soft accuracy is computed correctly.

    A single target sequence `[0, 1]` is scored against per-position
    probabilities `[0.9, 0.1]` and `[0.6, 0.4]` using the normalized
    similarity matrix built from `freq_dict`. The expected
    accuracy / denominator ratio is 0.75.
    """
    domain = domains.FixedLengthDiscreteDomain(
        vocab=domains.Vocabulary(tokens=['A', 'B', 'C']), length=2)
    targets = np.array([[0, 1]])
    # Literal shape is (1, 2, 2); presumably (batch, position, class).
    # NOTE(review): only 2 classes are given, while the vocabulary (and so the
    # similarity matrix) has 3 tokens -- confirm this is intended.
    logits = np.log([[[0.9, 0.1], [0.6, 0.4]]])
    freq_dict = {
        'A': {'A': 5, 'B': 3, 'C': 1},
        'B': {'A': 3, 'B': 5, 'C': 1},
        'C': {'A': 1, 'B': 1, 'C': 1},
    }
    accuracy, denominator = utils.compute_weighted_soft_accuracy(
        logits,
        targets,
        weights=None,
        matrix=utils.get_normalized_matrix(domain, freq_dict))
    # The ratio is a floating-point quotient; exact equality would be
    # fragile against rounding, so compare with a tolerance.
    self.assertAllClose(accuracy / denominator, 0.75)
def test_overfit(self):
    """Tests that the LM can memorize a tiny variable-length training batch.

    Trains for 100 steps on three short sequences and checks that every
    predicted token is correct except the unpredictable first tokens.
    """
    domain = domains.VariableLengthDiscreteDomain(
        vocab=domains.Vocabulary(
            tokens=['a', 'b', 'c'], include_bos=True, include_eos=True),
        length=9)
    seqs = [
        list('abcabcab'),
        list('bbbbbb'),
        list('cbacbacb'),
    ]
    enc = domain.encode(seqs, pad=True)
    # Sequences are padded out to length 9 with token 4 (presumably EOS given
    # the vocabulary flags -- confirm against the Vocabulary token order).
    self.assertAllEqual(
        [[0, 1, 2, 0, 1, 2, 0, 1, 4],
         [1, 1, 1, 1, 1, 1, 4, 4, 4],
         [2, 1, 0, 2, 1, 0, 2, 1, 4]],
        enc)
    enc = np.array(enc)
    model = lm_cls(
        domain=domain,
        learning_rate=0.01,
        dropout_rate=0.0,
        attention_dropout_rate=0.0)
    for _ in range(100):
        metrics = model.fit_batch(enc)
    # 2 less than perfect because the first token is unpredictable given just
    # <BOS>, and there are 3 total examples.
    denom = metrics['denominator'][0]
    # metrics['accuracy'] is used here as a count of correct predictions,
    # on the same scale as metrics['denominator'].
    correct = metrics['accuracy'][0]
    # Compare the counts directly. Dividing both sides by `denom` added
    # nothing and hid the raw counts in the failure message.
    self.assertEqual(denom - 2, correct)
def test_sampling_with_repetition_penalty(self, normalize):
    """Tests that the repetition penalty affects diversity.

    For penalties 1 (neutral), 100 and 1/100, samples 100 continuations of
    the prompt `<BOS> a` and measures the fraction of positions 1..length-1
    whose token already occurred earlier in the same decoded sample. A
    penalty of 100 should lower that fraction and 1/100 should raise it, each
    by more than 0.1 relative to the neutral setting.
    """
    length = 4
    domain = domains.FixedLengthDiscreteDomain(
        vocab=domains.Vocabulary(
            tokens=['a', 'b', 'c', 'd'], include_bos=True),
        length=length)
    batch_size = 100
    prompt_token = domain.vocab.tokens.index('a')
    checked_positions = batch_size * (length - 1)

    percent_repeats = []
    for penalty in [1, 100, 1 / 100]:
        lm = lm_cls(
            domain=domain,
            repetition_penalty=penalty,
            repetition_penalty_normalize=normalize)
        ones_column = jnp.ones((batch_size, 1)).astype(jnp.int32)
        prompt = jnp.concatenate(
            [ones_column * lm.bos_token, ones_column * prompt_token], axis=1)

        samples_str = domain.decode(lm.sample_with_prompt(prompt))
        logging.info('samples: %s', str(samples_str))

        # A position counts as a repeat when its token appears anywhere in
        # the prefix that precedes it.
        num_repeats = sum(
            sum(text[pos] in text[:pos] for pos in range(1, length))
            for text in samples_str)
        percent_repeats.append(num_repeats / checked_positions)

    logging.info('percent_repeats: %s', str(percent_repeats))
    neutral, discouraged, encouraged = percent_repeats
    self.assertGreater(neutral - discouraged, 0.1)
    self.assertGreater(encouraged - neutral, 0.1)
def _test_domain():
    """Returns a length-3 fixed-length domain over the tokens 'a', 'b', 'c'.

    The vocabulary also enables its BOS, mask and pad special tokens.
    """
    return domains.FixedLengthDiscreteDomain(
        vocab=domains.Vocabulary(
            tokens=['a', 'b', 'c'],
            include_bos=True,
            include_mask=True,
            include_pad=True,
        ),
        length=3,
    )
def test_flaxlm_evaluation(self):
    """Smoke test: `evaluation.evaluate` runs on a freshly built FlaxLM.

    Only checks that no exception is raised; the returned metrics are
    ignored.
    """
    # NOTE(review): the domain declares length=1 while each evaluation row
    # below holds two tokens -- confirm this mismatch is intended.
    vocab = domains.Vocabulary(tokens=range(2), include_bos=True)
    domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=1)
    rows = np.array([[0, 1], [1, 0]])
    dataset = tf.data.Dataset.from_tensor_slices((rows,))
    model = lm_cls(domain=domain)
    evaluation.evaluate(model, dataset)
def test_empirical_baseline_construction(self):
    """Tests that EmpiricalBaseline construction is correct.

    Training data holds tokens 0 and 1 twice each and never token 2, so with
    alpha=0 the stored distribution is expected to be [0.5, 0.5, 0]
    (presumably alpha is a smoothing term -- confirm in EmpiricalBaseline).
    """
    vocab = domains.Vocabulary(tokens=range(3), include_bos=True)
    domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=2)
    rows = np.array([[0, 1], [1, 0]])
    dataset = tf.data.Dataset.from_tensor_slices((rows,))
    baseline = evaluation.EmpiricalBaseline(domain, dataset, alpha=0)
    # Inspects a private attribute directly, so this test is coupled to the
    # implementation.
    self.assertAllEqual(baseline._empirical_dist, [0.5, 0.5, 0])
def _make_pretrained_transformer(self, **kwargs):
    """Trains a transformer to produce strings of alternating a's and b's.

    Args:
      **kwargs: Extra keyword arguments forwarded to `lm_cls`.

    Returns:
      A `(lm, domain)` tuple: the fitted model and the variable-length domain
      it was trained on.
    """
    training_strings = ['abab', 'baba'] * 64
    vocab = domains.Vocabulary(
        tokens=['a', 'b'], include_bos=True, include_eos=True)
    domain = domains.VariableLengthDiscreteDomain(
        vocab=vocab, length=len(training_strings[0]))
    encoded = np.array(domain.encode(training_strings, pad=False))
    lm = lm_cls(domain=domain, learning_rate=0.001, **kwargs)
    # A single full batch per epoch.
    lm.fit(encoded, batch_size=len(encoded), epochs=20)
    return lm, domain
def test_bos_does_not_appear_in_var_len_output(self):
    """Tests that BOS is not used for padding in var-len domain samples.

    Draws 10 samples and checks that none of them contains `lm.bos_token`.
    """
    vocab = domains.Vocabulary(tokens=[0, 1], include_eos=True)
    var_len_domain = domains.VariableLengthDiscreteDomain(
        vocab=vocab,
        length=10,
    )
    lm = lm_cls(domain=var_len_domain)
    bos = lm.bos_token
    for drawn in lm.sample(10):
        self.assertNotIn(bos, drawn)
def test_only_eos_after_eos(self):
    """Tests that the characters found after EOS are all equal to EOS.

    For each of 10 samples that contains EOS, everything from the first EOS
    to the end of the sample must be EOS. Samples without EOS are skipped.
    """
    vocab = domains.Vocabulary(tokens=[0, 1], include_eos=True)
    var_len_domain = domains.VariableLengthDiscreteDomain(
        vocab=vocab,
        length=10,
    )
    lm = lm_cls(domain=var_len_domain)
    eos = lm.eos_token
    for drawn in lm.sample(10):
        if eos not in drawn:
            continue
        first_eos = np.argwhere(drawn == eos)[0][0]
        tail = drawn[first_eos:]
        self.assertAllEqual(tail, [eos] * len(tail))
def test_empirical_baseline_evaluation(self):
    """Tests that EmpiricalBaseline evaluation is correct.

    Train and eval sets are the same two rows, so tokens 0 and 1 are equally
    frequent. The expected metrics are accuracy 0.5, perplexity 2 and loss
    near 0.69 (about ln 2, consistent with perplexity 2 if loss is a mean
    NLL in nats -- confirm).
    """
    # NOTE(review): the domain declares length=1 while each data row holds
    # two tokens -- confirm this mismatch is intended.
    vocab = domains.Vocabulary(tokens=range(2), include_bos=True)
    domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=1)

    train_rows = np.array([[0, 1], [1, 0]])
    eval_rows = np.array([[0, 1], [1, 0]])
    train_ds = tf.data.Dataset.from_tensor_slices((train_rows,))
    eval_ds = tf.data.Dataset.from_tensor_slices((eval_rows,))

    baseline = evaluation.EmpiricalBaseline(domain, train_ds)
    results = evaluation.evaluate(baseline, eval_ds)

    self.assertAllEqual(np.asarray(results['accuracy']), 0.5)
    self.assertAllClose(np.asarray(results['perplexity']), 2)
    self.assertAllClose(np.asarray(results['loss']), 0.69, atol=0.1)