def test_model_input_output_shape(self):
    """Check the recurrent model's batch-agnostic input and output shapes.

    The input is expected to be (batch, sequence_length) and the output
    (batch, sequence_length, vocab_size); the batch dimension is None.
    """
    vocab_size, sequence_length = 5, 7
    model = char_prediction_models.create_recurrent_model(
        vocab_size=vocab_size, sequence_length=sequence_length)

    expected_input_shape = (None, sequence_length)
    expected_output_shape = (None, sequence_length, vocab_size)
    self.assertEqual(model.input_shape, expected_input_shape)
    self.assertEqual(model.output_shape, expected_output_shape)
Пример #2
0
 def model_fn() -> model.Model:
   """Wrap the recurrent character-prediction Keras model as a `model.Model`.

   The loss is sparse categorical cross-entropy computed on logits. Both
   metrics (token count and categorical accuracy) ignore `pad_token`.
   `VOCAB_LENGTH`, `sequence_length`, `pad_token` and `task_datasets` are
   not defined here; they are read from an enclosing or module scope.
   """
   # NOTE(review): presumably invoked repeatedly by the federated framework
   # to build fresh model instances -- confirm against the caller.
   recurrent_model = char_prediction_models.create_recurrent_model(
       vocab_size=VOCAB_LENGTH, sequence_length=sequence_length)
   loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
   element_spec = task_datasets.element_type_structure
   # Each metric gets its own mask list, matching the original construction.
   token_counter = keras_metrics.NumTokensCounter(masked_tokens=[pad_token])
   masked_accuracy = keras_metrics.MaskedCategoricalAccuracy(
       masked_tokens=[pad_token])
   return keras_utils.from_keras_model(
       keras_model=recurrent_model,
       loss=loss_fn,
       input_spec=element_spec,
       metrics=[token_counter, masked_accuracy])
 def test_no_mask_zero_results_in_correct_mask(self):
   """A model built with mask_zero=False must report no output mask.

   The input deliberately contains token id 0 so that any zero-masking
   would be visible.
   """
   unmasked_model = char_prediction_models.create_recurrent_model(
       vocab_size=3, sequence_length=3, mask_zero=False)
   batch = tf.constant([[0, 1, 1]])
   self.assertIsNone(unmasked_model.compute_mask(batch, mask=None))
 def test_create_recurrent_model_raises_on_nonpositive_sequence_length(self):
   """create_recurrent_model must reject sequence_length=0 with ValueError."""
   expected_pattern = 'sequence_length must be a positive integer'
   with self.assertRaisesRegex(ValueError, expected_pattern):
     char_prediction_models.create_recurrent_model(
         vocab_size=3, sequence_length=0)