Example #1
0
def get_exp_config():
    """Build an ExperimentConfig pairing a BERT pretrainer task with a
    progressive trainer.

    The trainer config turns on checkpoint export with an export interval
    of 1 and `export_only_final_stage_ckpt` disabled.

    Returns:
        The assembled `cfg.ExperimentConfig`.
    """
    task_config = cfg.TaskConfig(model=bert.PretrainerConfig())
    trainer_config = trainer_lib.ProgressiveTrainerConfig(
        export_checkpoint=True,
        export_checkpoint_interval=1,
        export_only_final_stage_ckpt=False)
    return cfg.ExperimentConfig(task=task_config, trainer=trainer_config)
Example #2
0
def get_exp_config():
    """Assemble the BERT distillation ExperimentConfig and apply FLAGS.

    Returns:
        The result of `config_override` applied to the assembled config
        and `FLAGS`.
    """
    # Training and validation data configs; validation only flips
    # `is_training` off.
    train_data = pretrain_dataloader.BertPretrainDataConfig()
    eval_data = pretrain_dataloader.BertPretrainDataConfig(is_training=False)
    task_config = distillation.BertDistillationTaskConfig(
        train_data=train_data, validation_data=eval_data)

    progressive_config = distillation.BertDistillationProgressiveConfig()
    trainer_config = prog_trainer_lib.ProgressiveTrainerConfig(
        progressive=progressive_config,
        optimizer_config=optimization_config,
        train_steps=740000,
        checkpoint_interval=20000)

    experiment = cfg.ExperimentConfig(task=task_config,
                                      trainer=trainer_config)
    return config_override(experiment, FLAGS)
Example #3
0
def get_exp_config():
    """Assemble the masked-LM progressive-stacking ExperimentConfig.

    Returns:
        The result of `utils.config_override` applied to the assembled
        config and `FLAGS`.
    """
    task_config = masked_lm.MaskedLMConfig(
        train_data=pretrain_dataloader.BertPretrainDataConfig(),
        small_train_data=pretrain_dataloader.BertPretrainDataConfig(),
        validation_data=pretrain_dataloader.BertPretrainDataConfig(
            is_training=False))

    trainer_config = prog_trainer_lib.ProgressiveTrainerConfig(
        progressive=masked_lm.ProgStackingConfig(),
        optimizer_config=BertOptimizationConfig(),
        train_steps=1000000)

    # Restriction expressions are passed through verbatim as strings.
    is_training_restrictions = [
        'task.train_data.is_training != None',
        'task.validation_data.is_training != None'
    ]

    experiment = cfg.ExperimentConfig(
        task=task_config,
        trainer=trainer_config,
        restrictions=is_training_restrictions)
    return utils.config_override(experiment, FLAGS)
Example #4
0
    def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
        """Check the trainer's optimizer type and that training reports a loss.

        Args:
            mixed_precision_dtype: Forwarded to
                `RuntimeConfig.mixed_precision_dtype`.
            loss_scale: Forwarded to `RuntimeConfig.loss_scale`.
        """
        task_config = cfg.TaskConfig(model=bert.PretrainerConfig())
        runtime_config = cfg.RuntimeConfig(
            mixed_precision_dtype=mixed_precision_dtype,
            loss_scale=loss_scale)
        trainer_config = trainer_lib.ProgressiveTrainerConfig(
            export_checkpoint=True,
            export_checkpoint_interval=1,
            export_only_final_stage_ckpt=False)
        config = cfg.ExperimentConfig(
            task=task_config, runtime=runtime_config, trainer=trainer_config)

        task = TestPolicy(None, config.task)
        trainer = trainer_lib.ProgressiveTrainer(
            config, task, self.get_temp_dir())

        # A plain SGD optimizer is asserted unless float16 is combined with
        # an explicit loss scale; that remaining case is left unchecked.
        uses_float16 = mixed_precision_dtype == 'float16'
        if not uses_float16 or loss_scale is None:
            self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)

        num_steps = tf.convert_to_tensor(5, dtype=tf.int32)
        metrics = trainer.train(num_steps)
        self.assertIn('training_loss', metrics)
Example #5
0
    def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
        """Run an experiment in `flag_mode`, then once more in continuous eval.

        Args:
            distribution_strategy: Strategy whose scope is used to create the
                task and which is passed to `train_lib.run_experiment`.
            flag_mode: Mode string for the first `run_experiment` call. When
                it is 'eval' the test stops after checking the logs.
            run_post_eval: Forwarded to `run_experiment`; also decides whether
                the returned logs are expected to be non-empty or empty.
        """
        model_dir = self.get_temp_dir()
        base_config = cfg.ExperimentConfig(
            trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
            task=ProgTaskConfig())
        experiment_config = params_dict.override_params_dict(
            base_config, self._test_config, is_strict=False)

        with distribution_strategy.scope():
            task = task_factory.get_task(
                experiment_config.task, logging_dir=model_dir)

        # Everything except `mode` is shared by both run_experiment calls.
        shared_kwargs = dict(
            distribution_strategy=distribution_strategy,
            task=task,
            params=experiment_config,
            model_dir=model_dir,
            run_post_eval=run_post_eval)

        _, logs = train_lib.run_experiment(mode=flag_mode, **shared_kwargs)

        check_logs = self.assertNotEmpty if run_post_eval else self.assertEmpty
        check_logs(logs)

        if flag_mode == 'eval':
            return

        checkpoint_pattern = os.path.join(model_dir, 'checkpoint')
        self.assertNotEmpty(tf.io.gfile.glob(checkpoint_pattern))

        # Tests continuous evaluation.
        _, logs = train_lib.run_experiment(
            mode='continuous_eval', **shared_kwargs)
        print(logs)
Example #6
0
    def prepare_config(self, teacher_block_num, student_block_num,
                       transfer_teacher_layers):
        """Build a small distillation ExperimentConfig and save a teacher ckpt.

        Args:
            teacher_block_num: `num_blocks` for the teacher MobileBERT encoder.
            student_block_num: `num_blocks` for the student MobileBERT encoder.
            transfer_teacher_layers: Value assigned to
                `layer_wise_distill_config.transfer_teacher_layers`.

        Returns:
            The ExperimentConfig, with `task.teacher_model_init_checkpoint`
            pointing at the temp directory the teacher checkpoint was saved
            under.
        """

        def _small_pretrainer(num_blocks, mlm_activation):
            # Teacher and student configs differ only in these two arguments.
            encoder_config = encoders.EncoderConfig(
                type='mobilebert',
                mobilebert=encoders.MobileBertEncoderConfig(
                    num_blocks=num_blocks))
            next_sentence_head = bert.ClsHeadConfig(
                inner_dim=256,
                num_classes=2,
                dropout_rate=0.1,
                name='next_sentence')
            return bert.PretrainerConfig(
                encoder=encoder_config,
                cls_heads=[next_sentence_head],
                mlm_activation=mlm_activation)

        def _dummy_data():
            # Train and validation data use identical settings.
            return pretrain_dataloader.BertPretrainDataConfig(
                input_path='dummy',
                max_predictions_per_seq=76,
                seq_length=512,
                global_batch_size=10)

        # using small model for testing
        task_config = distillation.BertDistillationTaskConfig(
            teacher_model=_small_pretrainer(teacher_block_num, 'gelu'),
            student_model=_small_pretrainer(student_block_num, 'relu'),
            train_data=_dummy_data(),
            validation_data=_dummy_data())

        # set only 1 step for each stage
        progressive_config = distillation.BertDistillationProgressiveConfig()
        progressive_config.layer_wise_distill_config.transfer_teacher_layers = (
            transfer_teacher_layers)
        progressive_config.layer_wise_distill_config.num_steps = 1
        progressive_config.pretrain_distill_config.num_steps = 1

        lamb_optimizer = optimization.OptimizerConfig(
            type='lamb',
            lamb=optimization.LAMBConfig(
                weight_decay_rate=0.0001,
                exclude_from_weight_decay=[
                    'LayerNorm', 'layer_norm', 'bias', 'no_norm'
                ]))
        polynomial_lr = optimization.LrConfig(
            type='polynomial',
            polynomial=optimization.PolynomialLrConfig(
                initial_learning_rate=1.5e-3,
                decay_steps=10000,
                end_learning_rate=1.5e-3))
        linear_warmup = optimization.WarmupConfig(
            type='linear',
            linear=optimization.LinearWarmupConfig(warmup_learning_rate=0))
        optimization_config = optimization.OptimizationConfig(
            optimizer=lamb_optimizer,
            learning_rate=polynomial_lr,
            warmup=linear_warmup)

        trainer_config = prog_trainer_lib.ProgressiveTrainerConfig(
            progressive=progressive_config,
            optimizer_config=optimization_config)
        exp_config = cfg.ExperimentConfig(
            task=task_config, trainer=trainer_config)

        # Create a teacher model checkpoint.
        teacher_encoder = encoders.build_encoder(
            task_config.teacher_model.encoder)
        teacher_config = task_config.teacher_model
        # A falsy `cls_heads` yields an empty head list.
        teacher_cls_heads = [
            layers.ClassificationHead(**head_config.as_dict())
            for head_config in (teacher_config.cls_heads or [])
        ]

        mlm_initializer = tf.keras.initializers.TruncatedNormal(
            stddev=teacher_config.mlm_initializer_range)
        teacher_mlm = layers.MobileBertMaskedLM(
            embedding_table=teacher_encoder.get_embedding_table(),
            activation=tf_utils.get_activation(teacher_config.mlm_activation),
            initializer=mlm_initializer,
            name='cls/predictions')
        teacher_pretrainer = models.BertPretrainerV2(
            encoder_network=teacher_encoder,
            classification_heads=teacher_cls_heads,
            customized_masked_lm=teacher_mlm)

        # The model variables will be created after the forward call.
        teacher_pretrainer(teacher_pretrainer.inputs)
        teacher_checkpoint = tf.train.Checkpoint(
            **teacher_pretrainer.checkpoint_items)
        checkpoint_prefix = os.path.join(self.get_temp_dir(),
                                         'teacher_model.ckpt')
        teacher_checkpoint.save(checkpoint_prefix)
        # The init checkpoint is set to the directory, not the saved prefix.
        exp_config.task.teacher_model_init_checkpoint = self.get_temp_dir()

        return exp_config