Example #1
 def test_get_normalized_matrix(self):
     """Tests that the normalized matrix is computed correctly."""
     domain = domains.FixedLengthDiscreteDomain(
         vocab=domains.Vocabulary(tokens=['A', 'B', 'C']), length=2)
     freq_dict = {
         'A': {'A': 5, 'B': 3, 'C': 1},
         'B': {'A': 3, 'B': 5, 'C': 1},
         'C': {'A': 1, 'B': 1, 'C': 1},
     }
     matrix = utils.get_normalized_matrix(domain, freq_dict)
     expected_matrix = [[1, 0.5, 0], [0.5, 1, 0], [0, 0, 0]]
     self.assertAllEqual(matrix, expected_matrix)
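A quick sketch of where the expected matrix plausibly comes from: min-max scaling the raw counts reproduces the values asserted above. The scaling rule is an assumption inferred from the expected output, not necessarily get_normalized_matrix's exact implementation.

import numpy as np

freqs = np.array([[5., 3., 1.], [3., 5., 1.], [1., 1., 1.]])
normalized = (freqs - freqs.min()) / (freqs.max() - freqs.min())
print(normalized)  # [[1. 0.5 0.] [0.5 1. 0.] [0. 0. 0.]]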
Example #2
 def test_soft_accuracy(self):
     """Tests that soft accuracy is computed correctly."""
     domain = domains.FixedLengthDiscreteDomain(
         vocab=domains.Vocabulary(tokens=['A', 'B', 'C']), length=2)
     targets = np.array([[0, 1]])
     logits = np.log([[[0.9, 0.1], [0.6, 0.4]]])
     freq_dict = {
         'A': {'A': 5, 'B': 3, 'C': 1},
         'B': {'A': 3, 'B': 5, 'C': 1},
         'C': {'A': 1, 'B': 1, 'C': 1},
     }
     accuracy, denominator = utils.compute_weighted_soft_accuracy(
         logits,
         targets,
         weights=None,
         matrix=utils.get_normalized_matrix(domain, freq_dict))
     self.assertEqual(accuracy / denominator, 0.75)
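For reference, the expected 0.75 can be reproduced by hand if one assumes the soft accuracy credits each position with matrix[target][argmax(logits)]; this is an assumption consistent with the assertion above, not the library's documented formula.

import numpy as np

matrix = np.array([[1, 0.5, 0], [0.5, 1, 0], [0, 0, 0]])
logits = np.log([[[0.9, 0.1], [0.6, 0.4]]])
targets = np.array([[0, 1]])
preds = logits.argmax(axis=-1)      # [[0, 0]]
credit = matrix[targets, preds]     # [[1.0, 0.5]]
print(credit.sum() / targets.size)  # 0.75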
Example #3
  def test_overfit(self):
    domain = domains.VariableLengthDiscreteDomain(
        vocab=domains.Vocabulary(
            tokens=['a', 'b', 'c'], include_bos=True, include_eos=True),
        length=9)
    seqs = [
        list('abcabcab'),
        list('bbbbbb'),
        list('cbacbacb'),
    ]
    enc = domain.encode(seqs, pad=True)
    self.assertAllEqual(
        [[0, 1, 2, 0, 1, 2, 0, 1, 4],
         [1, 1, 1, 1, 1, 1, 4, 4, 4],
         [2, 1, 0, 2, 1, 0, 2, 1, 4]
         ], enc)
    enc = np.array(enc)
    model = lm_cls(
        domain=domain,
        learning_rate=0.01,
        dropout_rate=0.0,
        attention_dropout_rate=0.0)
    for _ in range(100):
      metrics = model.fit_batch(enc)

    # 2 less than perfect: the first token is unpredictable given just <BOS>,
    # and the 3 training sequences start with 3 different tokens, so at most
    # one of the three first tokens can be predicted correctly.
    denom = metrics['denominator'][0]
    correct = metrics['accuracy'][0]
    self.assertEqual((denom - 2) / denom, correct / denom)
Example #4
 def test_sampling_with_repetition_penalty(self, normalize):
     """Tests that the repetition penalty affects diversity."""
     length = 4
     domain = domains.FixedLengthDiscreteDomain(
         vocab=domains.Vocabulary(
             tokens=['a', 'b', 'c', 'd'], include_bos=True),
         length=length)
     percent_repeats = []
     for repetition_penalty in [1, 100, 1 / 100]:
         lm = lm_cls(domain=domain,
                     repetition_penalty=repetition_penalty,
                     repetition_penalty_normalize=normalize)
         batch_size = 100
         prompt_token = domain.vocab.tokens.index('a')
         prompt = jnp.concatenate([
             jnp.ones((batch_size, 1)).astype(jnp.int32) * lm.bos_token,
             jnp.ones((batch_size, 1)).astype(jnp.int32) * prompt_token
         ], axis=1)
         samples = lm.sample_with_prompt(prompt)
         samples_str = domain.decode(samples)
         logging.info('samples: %s', str(samples_str))
         num_repeats = 0
         for sample in samples_str:
             num_repeats += sum([
                 sample[:i].count(sample[i]) > 0
                 for i in np.arange(1, length)
             ])
         percent_repeats.append(num_repeats / (batch_size * (length - 1)))
     logging.info('percent_repeats: %s', str(percent_repeats))
     self.assertGreater(percent_repeats[0] - percent_repeats[1], 0.1)
     self.assertGreater(percent_repeats[2] - percent_repeats[0], 0.1)
Example #5
def _test_domain():
  vocab = domains.Vocabulary(
      tokens=['a', 'b', 'c'],
      include_bos=True,
      include_mask=True,
      include_pad=True)
  return domains.FixedLengthDiscreteDomain(vocab=vocab, length=3)
Example #6
 def test_flaxlm_evaluation(self):
     """Tests that FlaxLM evaluation runs."""
     domain = domains.FixedLengthDiscreteDomain(
         vocab=domains.Vocabulary(tokens=range(2), include_bos=True),
         length=1)
     eval_data = np.array([[0, 1], [1, 0]])
     eval_ds = tf.data.Dataset.from_tensor_slices((eval_data,))
     lm = lm_cls(domain=domain)
     evaluation.evaluate(lm, eval_ds)
Example #7
 def test_empirical_baseline_construction(self):
     """Tests that EmpiricalBaseline construction is correct."""
     domain = domains.FixedLengthDiscreteDomain(
         vocab=domains.Vocabulary(tokens=range(3), include_bos=True),
         length=2)
     train_data = np.array([[0, 1], [1, 0]])
     train_ds = tf.data.Dataset.from_tensor_slices((train_data,))
     eb = evaluation.EmpiricalBaseline(domain, train_ds, alpha=0)
     self.assertAllEqual(eb._empirical_dist, [0.5, 0.5, 0])
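With alpha=0 there is no smoothing, so the asserted distribution is just the token frequencies of the training data. A minimal sketch of that arithmetic (not the class's actual code path):

import numpy as np

train_data = np.array([[0, 1], [1, 0]])
counts = np.bincount(train_data.ravel(), minlength=3)  # [2, 2, 0]
print(counts / counts.sum())                           # [0.5, 0.5, 0.0]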
Example #8
 def _make_pretrained_transformer(self, **kwargs):
     """Trains a transformer to produce strings of alternating a's and b's."""
     seqs = ['abab', 'baba'] * 64
     domain = domains.VariableLengthDiscreteDomain(
         vocab=domains.Vocabulary(
             tokens=['a', 'b'], include_bos=True, include_eos=True),
         length=len(seqs[0]))
     enc_seqs = np.array(domain.encode(seqs, pad=False))
     lm = lm_cls(domain=domain, learning_rate=0.001, **kwargs)
     lm.fit(enc_seqs, batch_size=len(enc_seqs), epochs=20)
     return lm, domain
Example #9
 def test_bos_does_not_appear_in_var_len_output(self):
   """Tests that BOS is not used for padding in var-len domain samples."""
   domain = domains.VariableLengthDiscreteDomain(
       vocab=domains.Vocabulary(tokens=[0, 1], include_eos=True),
       length=10,
   )
   lm = lm_cls(domain=domain)
   samples = lm.sample(10)
   for sample in samples:
     self.assertNotIn(lm.bos_token, sample)
Example #10
 def test_only_eos_after_eos(self):
   """Tests that the characters found after EOS are all equal to EOS."""
   domain = domains.VariableLengthDiscreteDomain(
       vocab=domains.Vocabulary(tokens=[0, 1], include_eos=True),
       length=10,
   )
   lm = lm_cls(domain=domain)
   samples = lm.sample(10)
   for sample in samples:
     if lm.eos_token in sample:
       start_eos = np.argwhere(sample == lm.eos_token)[0][0]
       self.assertAllEqual(sample[start_eos:],
                           [lm.eos_token] * (len(sample) - start_eos))
Example #11
 def test_empirical_baseline_evaluation(self):
     """Tests that EmpiricalBaseline evaluation is correct."""
     domain = domains.FixedLengthDiscreteDomain(
         vocab=domains.Vocabulary(tokens=range(2), include_bos=True),
         length=1)
     train_data = np.array([[0, 1], [1, 0]])
     train_ds = tf.data.Dataset.from_tensor_slices((train_data,))
     eval_data = np.array([[0, 1], [1, 0]])
     eval_ds = tf.data.Dataset.from_tensor_slices((eval_data,))
     eb = evaluation.EmpiricalBaseline(domain, train_ds)
     metrics = evaluation.evaluate(eb, eval_ds)
     self.assertAllEqual(np.asarray(metrics['accuracy']), 0.5)
     self.assertAllClose(np.asarray(metrics['perplexity']), 2)
     self.assertAllClose(np.asarray(metrics['loss']), 0.69, atol=0.1)
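The asserted numbers follow from a uniform predictive distribution over the two observed tokens; a back-of-the-envelope check, assuming per-token cross-entropy loss and perplexity = exp(loss):

import numpy as np

p = 0.5                    # probability assigned to either token
loss = -np.log(p)          # ~0.693, matches the loss assertion (atol=0.1)
perplexity = np.exp(loss)  # 2.0
accuracy = 0.5             # a uniform predictor gets half the tokens right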