def test_sentence_prediction(self):
  """Round-trips a sentence-prediction model through SavedModel export.

  Builds a tiny one-layer BERT classifier, saves a checkpoint, exports it
  via the task's export module, reloads the SavedModel, and checks that the
  serving signature reproduces the in-memory model's outputs.
  """
  config = sentence_prediction.SentencePredictionConfig(
      model=sentence_prediction.ModelConfig(
          encoder=encoders.EncoderConfig(
              bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
          num_classes=2))
  task = sentence_prediction.SentencePredictionTask(config)
  model = task.build_model()
  # Save a checkpoint so the exporter has concrete weights to restore.
  checkpoint = tf.train.Checkpoint(model=model)
  saved_ckpt_path = checkpoint.save(self.get_temp_dir())
  module_cls = export_savedmodel.lookup_export_module(task)
  module_params = module_cls.Params(inputs_only=False)
  export_module = module_cls(params=module_params, model=model)
  export_dir = export_savedmodel_util.export(
      export_module,
      function_keys=["serve"],
      checkpoint_path=saved_ckpt_path,
      export_savedmodel_dir=self.get_temp_dir())
  # Reload the exported model and compare against the live model.
  loaded = tf.saved_model.load(export_dir)
  serving_fn = loaded.signatures["serving_default"]
  token_ids = tf.ones((1, 5), dtype=tf.int32)
  features = dict(
      input_word_ids=token_ids,
      input_mask=token_ids,
      input_type_ids=token_ids)
  expected = model(features)
  served = serving_fn(**features)
  self.assertAllClose(expected, served["outputs"])
  self.assertEqual(served["outputs"].shape, (1, 2))
def test_sentence_prediction_text(self, inputs_only):
  """Serves raw-text tf.Examples through the SentencePrediction module.

  Uses a fake vocab file plus two text fields ("foo", "bar") and verifies
  the serve_text_examples signature yields per-example class logits.
  """
  vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt")
  _create_fake_vocab_file(vocab_path)
  task_config = sentence_prediction.SentencePredictionConfig(
      model=sentence_prediction.ModelConfig(
          encoder=encoders.EncoderConfig(
              bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
          num_classes=2))
  task = sentence_prediction.SentencePredictionTask(task_config)
  model = task.build_model()
  serving_params = serving_modules.SentencePrediction.Params(
      inputs_only=inputs_only,
      parse_sequence_length=10,
      text_fields=["foo", "bar"],
      vocab_file=vocab_path)
  export_module = serving_modules.SentencePrediction(params=serving_params,
                                                     model=model)
  serialized_examples = _create_fake_serialized_examples({
      "foo": b"hello world",
      "bar": b"hello world"
  })
  signatures = export_module.get_inference_signatures({
      "serve_text_examples": "serving_default",
  })
  predictions = signatures["serving_default"](serialized_examples)
  self.assertEqual(predictions["outputs"].shape, (10, 2))
def teams_sentence_prediction() -> cfg.ExperimentConfig:
  r"""Teams GLUE."""
  # Training/eval dataloaders; eval keeps partial final batches.
  train_data = (
      sentence_prediction_dataloader.SentencePredictionDataConfig())
  validation_data = (
      sentence_prediction_dataloader.SentencePredictionDataConfig(
          is_training=False, drop_remainder=False))
  task = sentence_prediction.SentencePredictionConfig(
      model=sentence_prediction.ModelConfig(
          encoder=encoders.EncoderConfig(
              type="any", any=teams.TeamsEncoderConfig(num_layers=1))),
      train_data=train_data,
      validation_data=validation_data)
  return cfg.ExperimentConfig(
      task=task,
      trainer=cfg.TrainerConfig(optimizer_config=TeamsOptimizationConfig()),
      restrictions=[
          "task.train_data.is_training != None",
          "task.validation_data.is_training != None",
      ])
def roformer_glue() -> cfg.ExperimentConfig:
  r"""RoFormer GLUE fine-tuning experiment config."""
  # NOTE(review): docstring previously read "BigBird GLUE" — a copy-paste
  # slip; the encoder wired below is a RoformerEncoderConfig.
  config = cfg.ExperimentConfig(
      task=sentence_prediction.SentencePredictionConfig(
          model=sentence_prediction.ModelConfig(
              encoder=encoders.EncoderConfig(
                  type='any', any=roformer.RoformerEncoderConfig())),
          train_data=sentence_prediction_dataloader
          .SentencePredictionDataConfig(),
          # Eval keeps partial final batches so no examples are dropped.
          validation_data=sentence_prediction_dataloader
          .SentencePredictionDataConfig(
              is_training=False, drop_remainder=False)),
      trainer=cfg.TrainerConfig(
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adamw',
                  'adamw': {
                      'weight_decay_rate': 0.01,
                      # Standard BERT-style exclusions from weight decay.
                      'exclude_from_weight_decay':
                          ['LayerNorm', 'layer_norm', 'bias'],
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 3e-5,
                      'end_learning_rate': 0.0,
                  }
              },
              'warmup': {
                  'type': 'polynomial'
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
def get_model_config(self, num_classes):
  """Returns a one-layer BERT sentence-prediction model config."""
  encoder = encoders.EncoderConfig(
      bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))
  return sentence_prediction.ModelConfig(
      encoder=encoder, num_classes=num_classes)
def get_model_config(self, num_classes):
  """Returns a one-layer transformer sentence-prediction model config."""
  encoder = encoders.TransformerEncoderConfig(
      vocab_size=30522, num_layers=1)
  return sentence_prediction.ModelConfig(
      encoder=encoder, num_classes=num_classes)
def test_sentence_prediction(self, use_v2_feature_names):
  """Exercises serve/serve_examples signatures under both feature schemas.

  Covers: ids-only serving (inputs_only=True), full feature-dict serving
  (inputs_only=False), serialized tf.Example parsing with v1/v2 feature
  names, and rejection of unknown signature keys.
  """
  if use_v2_feature_names:
    word_ids_key = "input_word_ids"
    type_ids_key = "input_type_ids"
  else:
    word_ids_key = "input_ids"
    type_ids_key = "segment_ids"
  config = sentence_prediction.SentencePredictionConfig(
      model=sentence_prediction.ModelConfig(
          encoder=encoders.EncoderConfig(
              bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
          num_classes=2))
  task = sentence_prediction.SentencePredictionTask(config)
  model = task.build_model()

  # inputs_only=True: the serving function takes a single ids tensor.
  params = serving_modules.SentencePrediction.Params(
      inputs_only=True,
      parse_sequence_length=10,
      use_v2_feature_names=use_v2_feature_names)
  export_module = serving_modules.SentencePrediction(params=params,
                                                     model=model)
  functions = export_module.get_inference_signatures({
      "serve": "serving_default",
      "serve_examples": "serving_examples"
  })
  self.assertSameElements(functions.keys(),
                          ["serving_default", "serving_examples"])
  batch_ids = tf.ones((10, 10), dtype=tf.int32)
  outputs = functions["serving_default"](batch_ids)
  self.assertEqual(outputs["outputs"].shape, (10, 2))

  # inputs_only=False: the serving function takes the full feature dict.
  params = serving_modules.SentencePrediction.Params(
      inputs_only=False,
      parse_sequence_length=10,
      use_v2_feature_names=use_v2_feature_names)
  export_module = serving_modules.SentencePrediction(params=params,
                                                     model=model)
  functions = export_module.get_inference_signatures({
      "serve": "serving_default",
      "serve_examples": "serving_examples"
  })
  outputs = functions["serving_default"](
      input_word_ids=batch_ids,
      input_mask=batch_ids,
      input_type_ids=batch_ids)
  self.assertEqual(outputs["outputs"].shape, (10, 2))

  # Serialized tf.Example path, keyed by the feature-name schema under test.
  example_ids = tf.ones((10,), dtype=tf.int32)
  examples = _create_fake_serialized_examples({
      word_ids_key: example_ids,
      "input_mask": example_ids,
      type_ids_key: example_ids
  })
  outputs = functions["serving_examples"](examples)
  self.assertEqual(outputs["outputs"].shape, (10, 2))

  # Unknown signature keys must be rejected.
  with self.assertRaises(ValueError):
    _ = export_module.get_inference_signatures({"foo": None})