# Unit test for the ELECTRA pretraining task. Imports follow the TensorFlow
# Model Garden layout under `official/`.
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import electra_task


class ElectraPretrainTaskTest(tf.test.TestCase):
  # Test-case wrapper added so the method below is runnable; the class name
  # is assumed.

  def test_task(self):
    # A deliberately tiny config: one-layer BERT encoders for both the
    # generator and the discriminator, plus a next-sentence-prediction head.
    config = electra_task.ElectraPretrainConfig(
        model=electra.ElectraPretrainerConfig(
            generator_encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            discriminator_encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=1)),
            num_masked_tokens=20,
            sequence_length=128,
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=2, name="next_sentence")
            ]),
        train_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path="dummy",
            max_predictions_per_seq=20,
            seq_length=128,
            global_batch_size=1))
    task = electra_task.ElectraPretrainTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = task.build_inputs(config.train_data)
    iterator = iter(dataset)
    # `lr` is a deprecated alias; `learning_rate` is the supported argument.
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    # One train step and one validation step should run end to end.
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)
# Task config for ELECTRA pretraining. In the Model Garden this lives in the
# task module itself (alongside `ElectraPretrainTask`); it additionally needs:
import dataclasses

from official.core import config_definitions as cfg


# The dataclass decorator was missing; the annotated fields with defaults
# only behave as config fields with it applied.
@dataclasses.dataclass
class ElectraPretrainConfig(cfg.TaskConfig):
  """The model config."""
  model: electra.ElectraPretrainerConfig = electra.ElectraPretrainerConfig(
      cls_heads=[
          bert.ClsHeadConfig(
              inner_dim=768,
              num_classes=2,
              dropout_rate=0.1,
              name='next_sentence')
      ])
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
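# A minimal sketch (an assumption, following the Model Garden's usual pattern
# in `official.core`) of how this config is paired with its task class: the
# factory registration lets a training driver construct the task from the
# config alone. The real class body lives in electra_task.py and is elided.
#
#   from official.core import base_task
#   from official.core import task_factory
#
#   @task_factory.register_task_cls(ElectraPretrainConfig)
#   class ElectraPretrainTask(base_task.Task):
#     ...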