def test_model_input_output_shape(self):
    """Checks the input and output shapes of the recurrent model.

    The expected shapes leave the leading (batch) dimension as None.
    The input is (None, sequence_length), and the output adds a trailing
    vocab_size axis: (None, sequence_length, vocab_size).
    """
    vocab = 5
    seq_len = 7
    rnn_model = char_prediction_models.create_recurrent_model(vocab, seq_len)

    expected_input_shape = (None, seq_len)
    expected_output_shape = (None, seq_len, vocab)

    self.assertEqual(rnn_model.input_shape, expected_input_shape)
    self.assertEqual(rnn_model.output_shape, expected_output_shape)
def model_fn() -> model.Model:
    """Wraps a newly created Keras recurrent model via from_keras_model.

    A new Keras model is constructed on every call.

    ``VOCAB_LENGTH``, ``sequence_length``, ``pad_token`` and
    ``task_datasets`` are not defined here. They are free variables resolved
    from the surrounding scope. ``pad_token`` is passed as ``masked_tokens``
    to both metrics.
    """
    keras_model = char_prediction_models.create_recurrent_model(
        vocab_size=VOCAB_LENGTH, sequence_length=sequence_length)
    # The loss is built with from_logits=True.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    spec = task_datasets.element_type_structure
    metric_list = [
        keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
        keras_metrics.MaskedCategoricalAccuracy(masked_tokens=[pad_token]),
    ]
    return keras_utils.from_keras_model(
        keras_model=keras_model,
        loss=loss_fn,
        input_spec=spec,
        metrics=metric_list,
    )
def test_no_mask_zero_results_in_correct_mask(self):
    """Checks that compute_mask returns None when mask_zero=False.

    The check holds even though the input batch contains the token 0.
    """
    unmasked_model = char_prediction_models.create_recurrent_model(
        vocab_size=3,
        sequence_length=3,
        mask_zero=False,
    )
    batch = tf.constant([[0, 1, 1]])
    computed_mask = unmasked_model.compute_mask(batch, mask=None)
    self.assertIsNone(computed_mask)
def test_create_recurrent_model_raises_on_nonpositive_sequence_length(self):
    """Checks that sequence_length=0 makes create_recurrent_model raise.

    The call must raise a ValueError whose message matches
    'sequence_length must be a positive integer'.
    """
    self.assertRaisesRegex(
        ValueError,
        'sequence_length must be a positive integer',
        char_prediction_models.create_recurrent_model,
        vocab_size=3,
        sequence_length=0,
    )