Example #1
0
    def test_train_model_locally(self):
        """Runs one training step and one evaluation step on dummy data.

        Teacher and student pretrainers are built from the experiment params,
        invoked once on their own inputs, and then handed to the distillation
        trainer.
        """
        def _build_and_call(cfg, model_name):
            # Invoking the model on its own inputs creates its variables.
            built = model_builder.build_bert_pretrainer(
                pretrainer_cfg=cfg, name=model_name)
            built(built.inputs)
            return built

        teacher = _build_and_call(
            self.experiment_params.teacher_model, 'teacher')
        student = _build_and_call(
            self.experiment_params.student_model, 'student')
        distiller = mobilebert_edgetpu_trainer.MobileBERTEdgeTPUDistillationTrainer(
            teacher_model=teacher,
            student_model=student,
            strategy=self.strategy,
            experiment_params=self.experiment_params)

        # Swap in dummy datasets: loading the real ones times out in tests.
        distiller.train_dataset = _dummy_dataset()
        distiller.eval_dataset = _dummy_dataset()
        train_iterator, eval_iterator = (
            iter(distiller.train_dataset), iter(distiller.eval_dataset))

        distiller.train_loop_begin()
        distiller.train_step(train_iterator)
        distiller.eval_step(eval_iterator)
Example #2
0
 def test_initialization_with_encoder(self):
     """Checks that a caller-supplied encoder is reused by the pretrainer."""
     mobilebert_cfg = encoders.EncoderConfig(type='mobilebert')
     given_encoder = encoders.build_encoder(config=mobilebert_cfg)
     built = model_builder.build_bert_pretrainer(
         pretrainer_cfg=self.pretrainer_config, encoder=given_encoder)
     self.assertEqual(built.encoder_network, given_encoder)
 def setUp(self):
     """Builds a BertSpanLabeler on top of the student pretrainer's encoder."""
     super(ExportTfliteSquadTest, self).setUp()
     student_cfg = params.EdgeTPUBERTCustomParams().student_model
     pretrainer = model_builder.build_bert_pretrainer(
         student_cfg, name='pretrainer')
     init = tf.keras.initializers.TruncatedNormal(stddev=0.01)
     self.span_labeler = models.BertSpanLabeler(
         network=pretrainer.encoder_network, initializer=init)
    def test_load_checkpoint(self):
        """Saves freshly built pretrainer weights and loads them back."""
        student_cfg = params.EdgeTPUBERTCustomParams().student_model
        student_cfg.encoder.type = 'mobilebert'
        model = model_builder.build_bert_pretrainer(
            pretrainer_cfg=student_cfg, name='test_model')
        ckpt_path = self.create_tempfile().full_path
        # Run the model once so its variables exist before saving weights.
        model(model.inputs)
        model.save_weights(ckpt_path)

        utils.load_checkpoint(model, ckpt_path)
Example #5
0
 def test_default_initialization(self):
     """Initializes pretrainer model from scratch."""
     pretrainer = model_builder.build_bert_pretrainer(
         pretrainer_cfg=self.pretrainer_config, name='test_model')
     # Makes sure the pretrainer variables are created.
     _ = pretrainer(pretrainer.inputs)
     self.assertEqual(pretrainer.name, 'test_model')
     # Count the MobileBertTransformer layers in the encoder and compare
     # against num_blocks of a default MobileBertEncoderConfig.
     encoder = pretrainer.encoder_network
     default_number_layer = encoders.MobileBertEncoderConfig().num_blocks
     encoder_transformer_layer_counter = sum(
         isinstance(layer, modeling.layers.MobileBertTransformer)
         for layer in encoder.layers)
     self.assertEqual(default_number_layer,
                      encoder_transformer_layer_counter)
Example #6
0
 def test_initialization_with_mlm(self):
     """Checks that a caller-supplied masked-LM head is reused by the pretrainer."""
     hidden = encoders.MobileBertEncoderConfig().hidden_size
     embedding_layer = modeling.layers.MobileBertEmbedding(
         word_vocab_size=30522,
         word_embed_size=128,
         type_vocab_size=2,
         output_embed_size=hidden)
     # Call the embedding on a symbolic input so its weights get built.
     embedding_layer(tf.keras.layers.Input(shape=(None, ), dtype=tf.int32))
     given_mlm = modeling.layers.MobileBertMaskedLM(
         embedding_table=embedding_layer.word_embedding.embeddings)
     built = model_builder.build_bert_pretrainer(
         pretrainer_cfg=self.pretrainer_config, masked_lm=given_mlm)
     self.assertEqual(built.masked_lm, given_mlm)
Example #7
0
def main(argv: Sequence[str]) -> None:
    """Exports the MobileBERT-EdgeTPU SQuAD span labeler as a TFLite model.

    Builds the student encoder from the experiment params and wraps it in a
    `BertSpanLabeler`. Weights are restored from `FLAGS.model_checkpoint` when
    that flag is set. The model is converted with the TFLite converter and
    written to `model.tflite` under `FLAGS.export_path`.

    `FLAGS.quantization_method` selects the conversion mode:
      * 'hybrid' enables the default optimizations.
      * 'full-integer' additionally restricts to int8 builtin ops and
        calibrates with a representative dataset.
      * Any other value converts without optimizations.

    Args:
      argv: Positional command-line arguments; only the program name is
        accepted.

    Raises:
      app.UsageError: if extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Set up experiment params and load the configs from file/files.
    experiment_params = params.EdgeTPUBERTCustomParams()
    experiment_params = utils.config_override(experiment_params, FLAGS)

    # Change the input mask type to tf.float32 to avoid an additional cast op.
    experiment_params.student_model.encoder.mobilebert.input_mask_dtype = 'float32'

    # Experiments indicate using -120 as the mask value for Softmax is good
    # enough for both int8 and bfloat, so quantization_friendly is set to True
    # for both the quantized and the float model.
    pretrainer_model = model_builder.build_bert_pretrainer(
        experiment_params.student_model,
        name='pretrainer',
        quantization_friendly=True)

    encoder_network = pretrainer_model.encoder_network
    model = models.BertSpanLabeler(
        network=encoder_network,
        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.01))

    # Load model weights.
    if FLAGS.model_checkpoint is not None:
        checkpoint = tf.train.Checkpoint(model=model)
        checkpoint.restore(
            FLAGS.model_checkpoint).assert_existing_objects_matched()

    model_for_serving = build_model_for_serving(model)
    model_for_serving.summary()

    def _representative_dataset():
        """Yields up to 100 SQuAD training examples for int8 calibration."""
        dataset_params = question_answering_dataloader.QADataConfig()
        dataset_params.input_path = SQUAD_TRAIN_SPLIT
        dataset_params.drop_remainder = False
        dataset_params.global_batch_size = 1
        dataset_params.is_training = True

        dataset = orbit.utils.make_distributed_dataset(
            tf.distribute.get_strategy(), build_inputs, dataset_params)
        for example in dataset.take(100):
            inputs = example[0]
            input_word_ids = inputs['input_word_ids']
            input_mask = inputs['input_mask']
            input_type_ids = inputs['input_type_ids']
            yield [input_word_ids, input_mask, input_type_ids]

    # TODO(b/194449109): Need to save the model to file and then convert tflite
    # with 'tf.lite.TFLiteConverter.from_saved_model()' to get the expected
    # accuracy.
    #
    # The context manager keeps the TemporaryDirectory alive for the whole
    # save -> convert sequence and removes it afterwards.
    # `tempfile.TemporaryDirectory().name` dropped the object immediately, so
    # its finalizer deleted the directory before `save()` recreated it, and the
    # SavedModel was then never cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_for_serving.save(tmp_dir)

        converter = tf.lite.TFLiteConverter.from_saved_model(tmp_dir)
        if FLAGS.quantization_method in ('full-integer', 'hybrid'):
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
        if FLAGS.quantization_method == 'full-integer':
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.TFLITE_BUILTINS_INT8
            ]
            converter.inference_input_type = tf.int8
            converter.inference_output_type = tf.float32
            converter.representative_dataset = _representative_dataset

        # Conversion may still read the SavedModel directory, so it must
        # happen before the temporary directory is removed.
        tflite_quant_model = converter.convert()

    export_model_path = os.path.join(FLAGS.export_path, 'model.tflite')
    with tf.io.gfile.GFile(export_model_path, 'wb') as f:
        f.write(tflite_quant_model)
    logging.info('Successfully save the tflite to %s', export_model_path)
Example #8
0
def main(_):
    """Runs MobileBERT-EdgeTPU teacher -> student distillation training.

    Builds the teacher and student pretrainers inside the configured
    distribution strategy scope. The teacher checkpoint is required; the
    student checkpoint is optional. Training is driven by an `orbit.Controller`
    that checkpoints into the experiment's model directory.

    Raises:
      ValueError: if `FLAGS.mode` is not 'train', or if
        `teacher_model_init_checkpoint` is not configured.
    """
    # Fail fast on an unsupported mode rather than only after both models,
    # their checkpoints and the controller have been set up.
    if FLAGS.mode != 'train':
        raise ValueError('Unsupported mode, only support `train`')

    # Set up experiment params and load the configs from file/files.
    experiment_params = params.EdgeTPUBERTCustomParams()
    experiment_params = utils.config_override(experiment_params, FLAGS)
    model_dir = utils.get_model_dir(experiment_params, FLAGS)

    distribution_strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=experiment_params.runtime.distribution_strategy,
        all_reduce_alg=experiment_params.runtime.all_reduce_alg,
        num_gpus=experiment_params.runtime.num_gpus,
        tpu_address=experiment_params.runtime.tpu_address)

    with distribution_strategy.scope():
        teacher_model = model_builder.build_bert_pretrainer(
            pretrainer_cfg=experiment_params.teacher_model,
            quantization_friendly=False,
            name='teacher')
        student_model = model_builder.build_bert_pretrainer(
            pretrainer_cfg=experiment_params.student_model,
            quantization_friendly=True,
            name='student')

        # Load model weights. The teacher checkpoint is mandatory.
        teacher_ckpt_dir_or_file = experiment_params.teacher_model_init_checkpoint
        if not teacher_ckpt_dir_or_file:
            raise ValueError(
                '`teacher_model_init_checkpoint` is not specified.')
        utils.load_checkpoint(teacher_model, teacher_ckpt_dir_or_file)

        student_ckpt_dir_or_file = experiment_params.student_model_init_checkpoint
        if not student_ckpt_dir_or_file:
            # Makes sure the pretrainer variables are created.
            _ = student_model(student_model.inputs)
            # `warn` is a deprecated alias of `warning`.
            logging.warning(
                'No student checkpoint is provided, training might take '
                'much longer before converging.')
        else:
            utils.load_checkpoint(student_model, student_ckpt_dir_or_file)

        runner = mobilebert_edgetpu_trainer.MobileBERTEdgeTPUDistillationTrainer(
            teacher_model=teacher_model,
            student_model=student_model,
            strategy=distribution_strategy,
            experiment_params=experiment_params,
            export_ckpt_path=model_dir)

        # Save checkpoint for preemption handling.
        # Checkpoint for downstreaming tasks are saved separately inside the
        # runner's train_loop_end() function.
        checkpoint = tf.train.Checkpoint(
            teacher_model=runner.teacher_model,
            student_model=runner.student_model,
            layer_wise_optimizer=runner.layer_wise_optimizer,
            e2e_optimizer=runner.e2e_optimizer,
            current_step=runner.current_step)
        # Keep at most 5 checkpoints, saved no more often than every 20000
        # steps of `runner.current_step`.
        checkpoint_manager = tf.train.CheckpointManager(
            checkpoint,
            directory=model_dir,
            max_to_keep=5,
            step_counter=runner.current_step,
            checkpoint_interval=20000,
            init_fn=None)

    controller = orbit.Controller(
        trainer=runner,
        evaluator=runner,
        global_step=runner.current_step,
        strategy=distribution_strategy,
        steps_per_loop=experiment_params.orbit_config.steps_per_loop,
        summary_dir=os.path.join(model_dir, 'train'),
        eval_summary_dir=os.path.join(model_dir, 'eval'),
        checkpoint_manager=checkpoint_manager)

    controller.train(steps=experiment_params.orbit_config.total_steps)