def test_train_model_locally(self):
  """Runs one train step and one eval step of the distillation trainer."""
  exp_params = self.experiment_params

  # Build each pretrainer and call it once on its own symbolic inputs.
  teacher = model_builder.build_bert_pretrainer(
      pretrainer_cfg=exp_params.teacher_model, name='teacher')
  _ = teacher(teacher.inputs)
  student = model_builder.build_bert_pretrainer(
      pretrainer_cfg=exp_params.student_model, name='student')
  _ = student(student.inputs)

  distiller = mobilebert_edgetpu_trainer.MobileBERTEdgeTPUDistillationTrainer(
      teacher_model=teacher,
      student_model=student,
      strategy=self.strategy,
      experiment_params=exp_params)

  # Swap in dummy datasets: loading the real dataset makes the test time out.
  distiller.train_dataset = _dummy_dataset()
  distiller.eval_dataset = _dummy_dataset()
  train_iter = iter(distiller.train_dataset)
  eval_iter = iter(distiller.eval_dataset)

  # The test passes as long as none of these calls raises.
  distiller.train_loop_begin()
  distiller.train_step(train_iter)
  distiller.eval_step(eval_iter)
def test_initialization_with_encoder(self):
  """Checks that a caller-supplied encoder ends up as the encoder network."""
  encoder_cfg = encoders.EncoderConfig(type='mobilebert')
  given_encoder = encoders.build_encoder(config=encoder_cfg)

  model = model_builder.build_bert_pretrainer(
      pretrainer_cfg=self.pretrainer_config, encoder=given_encoder)

  # The pretrainer must expose the encoder that was passed in.
  self.assertEqual(model.encoder_network, given_encoder)
def setUp(self):
  """Builds a span labeler on top of the student pretrainer's encoder."""
  super(ExportTfliteSquadTest, self).setUp()
  student_cfg = params.EdgeTPUBERTCustomParams().student_model
  pretrainer = model_builder.build_bert_pretrainer(
      student_cfg, name='pretrainer')
  # Only the encoder of the pretrainer is reused by the span labeler.
  labeler_init = tf.keras.initializers.TruncatedNormal(stddev=0.01)
  self.span_labeler = models.BertSpanLabeler(
      network=pretrainer.encoder_network, initializer=labeler_init)
def test_load_checkpoint(self):
  """Saves a freshly built pretrainer's weights and loads them back."""
  student_cfg = params.EdgeTPUBERTCustomParams().student_model
  student_cfg.encoder.type = 'mobilebert'
  model = model_builder.build_bert_pretrainer(
      pretrainer_cfg=student_cfg, name='test_model')

  ckpt_path = self.create_tempfile().full_path
  # Call the model once so that its variables are created before saving.
  _ = model(model.inputs)
  model.save_weights(ckpt_path)

  # The test passes as long as loading does not raise.
  utils.load_checkpoint(model, ckpt_path)
def test_default_initialization(self):
  """Builds a pretrainer from scratch and checks its name and block count."""
  model = model_builder.build_bert_pretrainer(
      pretrainer_cfg=self.pretrainer_config, name='test_model')
  # Call the model once so that its variables are created.
  _ = model(model.inputs)
  self.assertEqual(model.name, 'test_model')

  # The encoder should hold as many MobileBertTransformer layers as the
  # default MobileBERT encoder config declares in `num_blocks`.
  encoder_layers = model.encoder_network.layers
  expected_blocks = encoders.MobileBertEncoderConfig().num_blocks
  transformer_cls = modeling.layers.MobileBertTransformer
  found_blocks = sum(
      1 for layer in encoder_layers if isinstance(layer, transformer_cls))
  self.assertEqual(expected_blocks, found_blocks)
def test_initialization_with_mlm(self):
  """Checks that a caller-supplied masked-LM head is used by the pretrainer."""
  hidden_size = encoders.MobileBertEncoderConfig().hidden_size
  embedding_layer = modeling.layers.MobileBertEmbedding(
      word_vocab_size=30522,
      word_embed_size=128,
      type_vocab_size=2,
      output_embed_size=hidden_size)

  # Call the embedding layer once on a symbolic input before reading its
  # word-embedding table.
  token_ids = tf.keras.layers.Input(shape=(None, ), dtype=tf.int32)
  _ = embedding_layer(token_ids)
  word_table = embedding_layer.word_embedding.embeddings

  given_mlm = modeling.layers.MobileBertMaskedLM(embedding_table=word_table)
  model = model_builder.build_bert_pretrainer(
      pretrainer_cfg=self.pretrainer_config, masked_lm=given_mlm)

  # The pretrainer must expose the MLM head that was passed in.
  self.assertEqual(model.masked_lm, given_mlm)
def main(argv: Sequence[str]) -> None:
  """Exports the student encoder as a SQuAD span-labeler TFLite model.

  Builds the student pretrainer from the (flag-overridden) experiment params,
  wraps its encoder in a `BertSpanLabeler`, optionally restores weights from
  `FLAGS.model_checkpoint`, converts the serving model to TFLite (quantized
  according to `FLAGS.quantization_method`) and writes `model.tflite` under
  `FLAGS.export_path`.

  Args:
    argv: Positional command-line arguments; only the program name is
      accepted.

  Raises:
    app.UsageError: If more than one positional argument is present.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Set up experiment params and load the configs from file/files.
  experiment_params = params.EdgeTPUBERTCustomParams()
  experiment_params = utils.config_override(experiment_params, FLAGS)

  # change the input mask type to tf.float32 to avoid additional casting op.
  experiment_params.student_model.encoder.mobilebert.input_mask_dtype = 'float32'

  # Experiments indicate using -120 as the mask value for Softmax is good enough
  # for both int8 and bfloat. So we set quantization_friendly to True for both
  # quant and float model.
  pretrainer_model = model_builder.build_bert_pretrainer(
      experiment_params.student_model,
      name='pretrainer',
      quantization_friendly=True)

  encoder_network = pretrainer_model.encoder_network
  model = models.BertSpanLabeler(
      network=encoder_network,
      initializer=tf.keras.initializers.TruncatedNormal(stddev=0.01))

  # Load model weights.
  if FLAGS.model_checkpoint is not None:
    checkpoint_dict = {'model': model}
    checkpoint = tf.train.Checkpoint(**checkpoint_dict)
    checkpoint.restore(
        FLAGS.model_checkpoint).assert_existing_objects_matched()

  model_for_serving = build_model_for_serving(model)
  model_for_serving.summary()

  def _representative_dataset():
    """Yields up to 100 single-example SQuAD inputs for calibration."""
    dataset_params = question_answering_dataloader.QADataConfig()
    dataset_params.input_path = SQUAD_TRAIN_SPLIT
    dataset_params.drop_remainder = False
    dataset_params.global_batch_size = 1
    dataset_params.is_training = True

    dataset = orbit.utils.make_distributed_dataset(
        tf.distribute.get_strategy(), build_inputs, dataset_params)
    for example in dataset.take(100):
      inputs = example[0]
      input_word_ids = inputs['input_word_ids']
      input_mask = inputs['input_mask']
      input_type_ids = inputs['input_type_ids']
      yield [input_word_ids, input_mask, input_type_ids]

  # TODO(b/194449109): Need to save the model to file and then convert tflite
  # with 'tf.lite.TFLiteConverter.from_saved_model()' to get the expected
  # accuracy
  #
  # Hold the TemporaryDirectory open with `with` until conversion is done.
  # Taking `.name` from an unreferenced instance lets its finalizer delete the
  # directory right away, and the path re-created by `save()` is then never
  # cleaned up.
  with tempfile.TemporaryDirectory() as tmp_dir:
    model_for_serving.save(tmp_dir)

    converter = tf.lite.TFLiteConverter.from_saved_model(tmp_dir)
    if FLAGS.quantization_method in ['full-integer', 'hybrid']:
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if FLAGS.quantization_method in ['full-integer']:
      converter.target_spec.supported_ops = [
          tf.lite.OpsSet.TFLITE_BUILTINS_INT8
      ]
      converter.inference_input_type = tf.int8
      converter.inference_output_type = tf.float32
      converter.representative_dataset = _representative_dataset

    # Convert inside the block in case the converter reads the SavedModel
    # lazily at convert() time.
    tflite_quant_model = converter.convert()

  export_model_path = os.path.join(FLAGS.export_path, 'model.tflite')
  with tf.io.gfile.GFile(export_model_path, 'wb') as f:
    f.write(tflite_quant_model)
  logging.info('Successfully save the tflite to %s', FLAGS.export_path)
def main(_):
  """Runs teacher-student distillation training for the MobileBERT student.

  Builds the teacher and student pretrainers under the configured distribution
  strategy, restores the teacher (required) and student (optional) checkpoints,
  and drives `MobileBERTEdgeTPUDistillationTrainer` through an
  `orbit.Controller` for `orbit_config.total_steps` steps.

  Args:
    _: Unused positional command-line arguments.

  Raises:
    ValueError: If `FLAGS.mode` is not `train`, or if
      `teacher_model_init_checkpoint` is not set in the experiment params.
  """
  # Fail fast: reject an unsupported mode before building models and loading
  # checkpoints.
  if FLAGS.mode != 'train':
    raise ValueError('Unsupported mode, only support `train`')

  # Set up experiment params and load the configs from file/files.
  experiment_params = params.EdgeTPUBERTCustomParams()
  experiment_params = utils.config_override(experiment_params, FLAGS)
  model_dir = utils.get_model_dir(experiment_params, FLAGS)

  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=experiment_params.runtime.distribution_strategy,
      all_reduce_alg=experiment_params.runtime.all_reduce_alg,
      num_gpus=experiment_params.runtime.num_gpus,
      tpu_address=experiment_params.runtime.tpu_address)

  with distribution_strategy.scope():
    teacher_model = model_builder.build_bert_pretrainer(
        pretrainer_cfg=experiment_params.teacher_model,
        quantization_friendly=False,
        name='teacher')
    student_model = model_builder.build_bert_pretrainer(
        pretrainer_cfg=experiment_params.student_model,
        quantization_friendly=True,
        name='student')

    # Load model weights.
    teacher_ckpt_dir_or_file = experiment_params.teacher_model_init_checkpoint
    if not teacher_ckpt_dir_or_file:
      raise ValueError(
          '`teacher_model_init_checkpoint` is not specified.')
    utils.load_checkpoint(teacher_model, teacher_ckpt_dir_or_file)

    student_ckpt_dir_or_file = experiment_params.student_model_init_checkpoint
    if not student_ckpt_dir_or_file:
      # Makes sure the pretrainer variables are created.
      _ = student_model(student_model.inputs)
      # `logging.warn` is a deprecated alias; `logging.warning` is the
      # supported spelling.
      logging.warning(
          'No student checkpoint is provided, training might take '
          'much longer before converging.')
    else:
      utils.load_checkpoint(student_model, student_ckpt_dir_or_file)

    runner = mobilebert_edgetpu_trainer.MobileBERTEdgeTPUDistillationTrainer(
        teacher_model=teacher_model,
        student_model=student_model,
        strategy=distribution_strategy,
        experiment_params=experiment_params,
        export_ckpt_path=model_dir)

    # Save checkpoint for preemption handling.
    # Checkpoint for downstreaming tasks are saved separately inside the
    # runner's train_loop_end() function.
    checkpoint = tf.train.Checkpoint(
        teacher_model=runner.teacher_model,
        student_model=runner.student_model,
        layer_wise_optimizer=runner.layer_wise_optimizer,
        e2e_optimizer=runner.e2e_optimizer,
        current_step=runner.current_step)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=5,
        step_counter=runner.current_step,
        checkpoint_interval=20000,
        init_fn=None)

  # The runner serves as both trainer and evaluator for the controller.
  controller = orbit.Controller(
      trainer=runner,
      evaluator=runner,
      global_step=runner.current_step,
      strategy=distribution_strategy,
      steps_per_loop=experiment_params.orbit_config.steps_per_loop,
      summary_dir=os.path.join(model_dir, 'train'),
      eval_summary_dir=os.path.join(model_dir, 'eval'),
      checkpoint_manager=checkpoint_manager)

  controller.train(steps=experiment_params.orbit_config.total_steps)