Example #1
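A test method, excerpted from a parameterized test class, that exercises BertDistillationTask end to end: it builds the task from a generated experiment config, creates metrics and stage-0 train/eval datasets, runs one train step and one validation step at every progressive stage, and finally saves a model checkpoint.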
    def test_task(self, teacher_block_num, student_block_num,
                  transfer_teacher_layers):
        exp_config = self.prepare_config(teacher_block_num, student_block_num,
                                         transfer_teacher_layers)
        bert_distillation_task = distillation.BertDistillationTask(
            strategy=tf.distribute.get_strategy(),
            progressive=exp_config.trainer.progressive,
            optimizer_config=exp_config.trainer.optimizer_config,
            task_config=exp_config.task)
        metrics = bert_distillation_task.build_metrics()
        train_dataset = bert_distillation_task.get_train_dataset(stage_id=0)
        train_iterator = iter(train_dataset)

        eval_dataset = bert_distillation_task.get_eval_dataset(stage_id=0)
        eval_iterator = iter(eval_dataset)
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

        # Run a train and a validation step at every stage, including the
        # final pretraining stage.
        for stage in range(student_block_num + 1):
            step = stage
            bert_distillation_task.update_pt_stage(step)
            model = bert_distillation_task.get_model(stage, None)
            bert_distillation_task.initialize(model)
            bert_distillation_task.train_step(next(train_iterator),
                                              model,
                                              optimizer,
                                              metrics=metrics)
            bert_distillation_task.validation_step(next(eval_iterator),
                                                   model,
                                                   metrics=metrics)

        logging.info('Saving model checkpoint')
        ckpt = tf.train.Checkpoint(model=model)
        ckpt.save(self.get_temp_dir())
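The three test parameters suggest the method is driven by absl's parameterized test harness. Below is a minimal, hypothetical sketch of how such a method could be declared; the class name and the parameter tuples are illustrative assumptions, not values from the original test file.

from absl.testing import parameterized
import tensorflow as tf


class BertDistillationTaskTest(tf.test.TestCase, parameterized.TestCase):

    @parameterized.parameters(
        (4, 2, None),       # hypothetical: 4 teacher blocks, 2 student blocks
        (6, 3, [0, 2, 4]),  # hypothetical: transfer selected teacher layers
    )
    def test_task(self, teacher_block_num, student_block_num,
                  transfer_teacher_layers):
        ...  # body as in the example above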
Example #2
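A training entry point that wires the same task into the Model Garden training loop: it parses gin configs, optionally enables mixed precision, builds a distribution strategy, constructs the BertDistillationTask under that strategy's scope, and hands everything to train_lib.run_experiment.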
# File-level imports this snippet relies on. The module paths follow the usual
# TF Model Garden layout; the exact import path of `distillation` (which
# provides BertDistillationTask) depends on the project. `FLAGS` and
# `get_exp_config` are assumed to be defined elsewhere in the same file.
import gin

from absl import logging

from official.common import distribute_utils
from official.core import train_lib
from official.modeling import performance


def main(_):
    logging.info('Parsing config files...')
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    params = get_exp_config()

    # Set the mixed-precision policy. 'mixed_float16' or 'mixed_bfloat16' can
    # speed the model up significantly by computing in float16 on GPUs and
    # bfloat16 on TPUs; loss_scale takes effect only when the dtype is
    # float16.
    if params.runtime.mixed_precision_dtype:
        performance.set_mixed_precision_policy(
            params.runtime.mixed_precision_dtype,
            params.runtime.loss_scale,
            use_experimental_api=True)
    distribution_strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        tpu_address=params.runtime.tpu)

    with distribution_strategy.scope():
        task = distillation.BertDistillationTask(
            strategy=distribution_strategy,
            progressive=params.trainer.progressive,
            optimizer_config=params.trainer.optimizer_config,
            task_config=params.task)

    train_lib.run_experiment(distribution_strategy=distribution_strategy,
                             task=task,
                             mode=FLAGS.mode,
                             params=params,
                             model_dir=FLAGS.model_dir)
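The snippet reads several absl flags (gin_file, gin_params, mode, model_dir) whose definitions are not shown. A minimal sketch of what those definitions and the script entry point could look like; the defaults and help strings are assumptions:

from absl import app, flags

flags.DEFINE_multi_string('gin_file', None, 'Paths to gin config files.')
flags.DEFINE_multi_string('gin_params', None, 'Gin parameter bindings.')
flags.DEFINE_string('mode', 'train', 'Job mode, e.g. train or train_and_eval.')
flags.DEFINE_string('model_dir', None, 'Directory for checkpoints and summaries.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)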