예제 #1
0
 def test_get_normalized_matrix(self):
     """Checks `utils.get_normalized_matrix` against a hand-written result.

     A 3-token vocabulary ('A', 'B', 'C') and a nested count table are passed
     to the helper; the returned matrix must equal the literal 3x3 matrix
     below element for element.
     """
     vocab = domains.Vocabulary(tokens=['A', 'B', 'C'])
     domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=2)
     # Nested token -> token -> count table handed to the helper as-is.
     freq_dict = {
         'A': {'A': 5, 'B': 3, 'C': 1},
         'B': {'A': 3, 'B': 5, 'C': 1},
         'C': {'A': 1, 'B': 1, 'C': 1},
     }
     expected_matrix = [
         [1, 0.5, 0],
         [0.5, 1, 0],
         [0, 0, 0],
     ]
     matrix = utils.get_normalized_matrix(domain, freq_dict)
     self.assertAllEqual(matrix, expected_matrix)
예제 #2
0
 def test_soft_accuracy(self):
     """Tests that soft accuracy is computed correctly.

     One two-position target sequence and its log-probabilities are scored
     with `utils.compute_weighted_soft_accuracy`, using the matrix returned by
     `utils.get_normalized_matrix` for the count table below. The returned
     accuracy divided by the returned denominator must be 0.75.
     """
     vocab = domains.Vocabulary(tokens=['A', 'B', 'C'])
     domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=2)
     targets = np.array([[0, 1]])
     logits = np.log([[[0.9, 0.1], [0.6, 0.4]]])
     # Nested token -> token -> count table handed to the matrix helper.
     freq_dict = {
         'A': {'A': 5, 'B': 3, 'C': 1},
         'B': {'A': 3, 'B': 5, 'C': 1},
         'C': {'A': 1, 'B': 1, 'C': 1},
     }
     matrix = utils.get_normalized_matrix(domain, freq_dict)
     accuracy, denominator = utils.compute_weighted_soft_accuracy(
         logits,
         targets,
         weights=None,
         matrix=matrix)
     # The ratio is a floating-point quantity, so compare with a tolerance
     # rather than exact equality; an exactly equal value still passes.
     self.assertAlmostEqual(float(accuracy / denominator), 0.75)
예제 #3
0
 def test_bos_does_not_appear_in_fixed_len_output(self):
   """Checks that no sample drawn from a fixed-length domain contains BOS.

   Builds a model over a length-10, 2-token domain, calls `sample(10)` and
   asserts that `lm.bos_token` is absent from every returned sample.
   """
   fixed_len_domain = domains.FixedLengthDiscreteDomain(
       vocab_size=2, length=10)
   lm = lm_cls(domain=fixed_len_domain)
   for drawn in lm.sample(10):
     self.assertNotIn(lm.bos_token, drawn)
예제 #4
0
 def test_sampling_with_repetition_penalty(self, normalize):
     """Checks that the repetition penalty changes how often tokens repeat.

     For each penalty in (1, 100, 1/100) a model is built, 100 prompts of the
     form [BOS, 'a'] are completed, and the fraction of positions (from index
     1 onward) whose token already occurred earlier in the decoded sample is
     recorded. The penalty of 100 must lower that fraction by more than 0.1
     relative to a penalty of 1, and the penalty of 1/100 must raise it by
     more than 0.1.

     Args:
       normalize: Forwarded to the model as `repetition_penalty_normalize`.
     """
     length = 4
     batch_size = 100
     vocab = domains.Vocabulary(tokens=['a', 'b', 'c', 'd'], include_bos=True)
     domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=length)
     prompt_token = domain.vocab.tokens.index('a')

     def count_repeats(sample):
         # Number of positions 1..length-1 whose element already appeared
         # somewhere before that position in the same sample.
         return sum(
             sample[:pos].count(sample[pos]) > 0 for pos in range(1, length))

     percent_repeats = []
     for penalty in (1, 100, 1 / 100):
         lm = lm_cls(domain=domain,
                     repetition_penalty=penalty,
                     repetition_penalty_normalize=normalize)
         # Two int32 columns: the BOS token followed by the prompt token.
         bos_column = (
             jnp.ones((batch_size, 1)).astype(jnp.int32) * lm.bos_token)
         token_column = (
             jnp.ones((batch_size, 1)).astype(jnp.int32) * prompt_token)
         prompt = jnp.concatenate([bos_column, token_column], axis=1)
         samples_str = domain.decode(lm.sample_with_prompt(prompt))
         logging.info('samples: %s', str(samples_str))
         num_repeats = sum(count_repeats(sample) for sample in samples_str)
         percent_repeats.append(num_repeats / (batch_size * (length - 1)))
     logging.info('percent_repeats: %s', str(percent_repeats))

     baseline, penalized, encouraged = percent_repeats
     self.assertGreater(baseline - penalized, 0.1)
     self.assertGreater(encouraged - baseline, 0.1)
예제 #5
0
def _test_domain():
  """Builds a small fixed-length domain for tests.

  Returns:
    A `domains.FixedLengthDiscreteDomain` of length 3 whose vocabulary holds
    the tokens 'a', 'b', 'c' and is created with BOS, mask and pad enabled.
  """
  vocab_options = dict(
      tokens=['a', 'b', 'c'],
      include_bos=True,
      include_mask=True,
      include_pad=True,
  )
  return domains.FixedLengthDiscreteDomain(
      vocab=domains.Vocabulary(**vocab_options), length=3)
예제 #6
0
 def test_flaxlm_evaluation(self):
     """Smoke test: `evaluation.evaluate` runs on a freshly built model.

     No metric values are asserted; the test passes if evaluation over a tiny
     two-row dataset completes without raising.
     """
     vocab = domains.Vocabulary(tokens=range(2), include_bos=True)
     domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=1)
     rows = np.array([[0, 1], [1, 0]])
     eval_ds = tf.data.Dataset.from_tensor_slices((rows, ))
     evaluation.evaluate(lm_cls(domain=domain), eval_ds)
예제 #7
0
 def test_empirical_baseline_construction(self):
     """Checks the distribution stored by a newly built EmpiricalBaseline.

     With a 3-token vocabulary, training rows [0, 1] and [1, 0], and
     `alpha=0`, the baseline's `_empirical_dist` must equal [0.5, 0.5, 0].
     """
     vocab = domains.Vocabulary(tokens=range(3), include_bos=True)
     domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=2)
     rows = np.array([[0, 1], [1, 0]])
     train_ds = tf.data.Dataset.from_tensor_slices((rows, ))
     baseline = evaluation.EmpiricalBaseline(domain, train_ds, alpha=0)
     self.assertAllEqual(baseline._empirical_dist, [0.5, 0.5, 0])
예제 #8
0
    def test_count_params(self):
        """Checks the model's parameter count and exercises the param helpers.

        A model over a length-4, 2-token domain must report 13059 parameters
        via `utils.param_count`. `utils.param_pprint` is only required to
        run, and `utils.param_reduce(..., log=True)` must return a dict.
        """
        small_domain = domains.FixedLengthDiscreteDomain(
            length=4, vocab_size=2)
        lm = lm_cls(domain=small_domain)
        self.assertEqual(13059, utils.param_count(lm))

        # The remaining helpers are only smoke-checked.
        utils.param_pprint(lm)
        self.assertIsInstance(utils.param_reduce(lm, log=True), dict)
예제 #9
0
 def setUp(self):
   """Creates the domain and model fixtures used by the tests.

   Stores a length-6, 4-token domain on `self._domain` and a
   `models.FlaxModel` built over it on `self.lm`. The model is created with
   `pmap=False`, `with_bos=True`, both the 'logits' and 'regression' output
   heads, and whatever extra settings `lm_cfg` carries. The parent `setUp`
   runs afterwards.
   """
   model_factory = functools.partial(
       models.FlaxModel,
       pmap=False,
       with_bos=True,
       output_head=('logits', 'regression'),
       **lm_cfg)
   self._domain = domains.FixedLengthDiscreteDomain(length=6, vocab_size=4)
   self.lm = model_factory(domain=self._domain)
   super().setUp()
예제 #10
0
 def test_empirical_baseline_evaluation(self):
     """Checks the metrics from evaluating an EmpiricalBaseline.

     The baseline is built from, and evaluated on, the same two rows
     ([0, 1] and [1, 0]). Expected metrics: accuracy exactly 0.5, perplexity
     close to 2, and loss within 0.1 of 0.69.
     """
     vocab = domains.Vocabulary(tokens=range(2), include_bos=True)
     domain = domains.FixedLengthDiscreteDomain(vocab=vocab, length=1)

     def as_dataset(rows):
         # Wraps the rows in a 1-tuple, matching the dataset element layout
         # used throughout this test.
         return tf.data.Dataset.from_tensor_slices((np.array(rows), ))

     train_ds = as_dataset([[0, 1], [1, 0]])
     eval_ds = as_dataset([[0, 1], [1, 0]])
     baseline = evaluation.EmpiricalBaseline(domain, train_ds)
     metrics = evaluation.evaluate(baseline, eval_ds)

     def metric(name):
         return np.asarray(metrics[name])

     self.assertAllEqual(metric('accuracy'), 0.5)
     self.assertAllClose(metric('perplexity'), 2)
     self.assertAllClose(metric('loss'), 0.69, atol=0.1)
예제 #11
0
 def test_output_head(self, output_head, multiple_heads):
     """Checks the structure of `models.predict_step` output per head spec.

     Args:
       output_head: Head specification forwarded to `models.predict_step`.
       multiple_heads: When truthy, the output must be a dict with one entry
         per element of `output_head`; otherwise it must be a single array of
         shape (batch, sequence, lm.vocab_size) taken from the inputs.
     """
     domain = domains.FixedLengthDiscreteDomain(vocab_size=2, length=2)
     # Inputs are drawn before the model is built, as in the original order.
     inputs = domain.sample_uniformly(8)
     lm = lm_cls(domain=domain, pmap=False)
     outputs = models.predict_step(
         lm.optimizer.target,
         inputs,
         preprocess_fn=lm.preprocess,
         output_head=output_head)

     if not multiple_heads:
         # We should have gotten a single output, the logits.
         expected_shape = (inputs.shape[0], inputs.shape[1], lm.vocab_size)
         self.assertEqual(outputs.shape, expected_shape)
         return

     self.assertIsInstance(outputs, dict)
     self.assertLen(outputs, len(output_head))
예제 #12
0
 def test_positional_encodings(self, positional_encoding_module):
     """Smoke test: the model samples with the given positional encoding.

     Args:
       positional_encoding_module: Forwarded to the model constructor under
         the same keyword; the test only requires that `sample(1)` runs.
     """
     model_kwargs = dict(
         domain=domains.FixedLengthDiscreteDomain(vocab_size=2, length=2),
         positional_encoding_module=positional_encoding_module,
     )
     lm_cls(**model_kwargs).sample(1)
예제 #13
0
def _test_domain():
  """Returns a `domains.FixedLengthDiscreteDomain` with length 3, vocab size 4."""
  domain = domains.FixedLengthDiscreteDomain(vocab_size=4, length=3)
  return domain