def test_task(self):
        """Smoke-tests the ELECTRA pretraining task end to end on one batch.

        Builds a minimal config (1-layer BERT generator/discriminator, dummy
        input path), constructs the task's model, metrics and dataset, then
        runs a single train step and a single validation step to verify the
        whole pipeline executes without error.
        """
        config = electra_task.ElectraPretrainConfig(
            model=electra.ElectraPretrainerConfig(
                # Tiny 1-layer encoders keep the test fast; 30522 is the
                # standard BERT WordPiece vocab size.
                generator_encoder=encoders.EncoderConfig(
                    bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                    num_layers=1)),
                discriminator_encoder=encoders.EncoderConfig(
                    bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                    num_layers=1)),
                num_masked_tokens=20,
                sequence_length=128,
                cls_heads=[
                    bert.ClsHeadConfig(inner_dim=10,
                                       num_classes=2,
                                       name="next_sentence")
                ]),
            # max_predictions_per_seq / seq_length must match the model's
            # num_masked_tokens / sequence_length above.
            train_data=pretrain_dataloader.BertPretrainDataConfig(
                input_path="dummy",
                max_predictions_per_seq=20,
                seq_length=128,
                global_batch_size=1))
        task = electra_task.ElectraPretrainTask(config)
        model = task.build_model()
        metrics = task.build_metrics()
        dataset = task.build_inputs(config.train_data)

        iterator = iter(dataset)
        # Use the canonical `learning_rate` kwarg: the deprecated `lr` alias
        # has been removed in newer Keras releases.
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
        task.train_step(next(iterator), model, optimizer, metrics=metrics)
        task.validation_step(next(iterator), model, metrics=metrics)
# Example #2
class ElectraPretrainConfig(cfg.TaskConfig):
    """Config for the ELECTRA pretraining task: model plus data pipelines."""
    # Pretrainer (generator + discriminator) configuration. Defaults to a
    # single BERT-style next-sentence classification head.
    # NOTE(review): class-level default instances appear to be the config
    # framework's convention (cfg.TaskConfig) — presumably it deep-copies
    # them per task instance; confirm before adding mutable defaults here.
    model: electra.ElectraPretrainerConfig = electra.ElectraPretrainerConfig(
        cls_heads=[
            bert.ClsHeadConfig(inner_dim=768,
                               num_classes=2,
                               dropout_rate=0.1,
                               name='next_sentence')
        ])
    # Training input pipeline; callers are expected to override with a
    # concrete pretraining data config (e.g. BertPretrainDataConfig).
    train_data: cfg.DataConfig = cfg.DataConfig()
    # Validation input pipeline; same expectation as train_data.
    validation_data: cfg.DataConfig = cfg.DataConfig()