Example #1
def test_task_with_hub(self):
  # Builds the dual-encoder task config from a locally exported TF-Hub BERT
  # module and runs the task end to end.
  hub_module_url = self._export_bert_tfhub()
  config = dual_encoder.DualEncoderConfig(
      hub_module_url=hub_module_url,
      model=self.get_model_config(),
      train_data=self._train_data_config)
  self._run_task(config)
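
A hedged illustration of the same hub_module_url field outside the test harness (my assumption, not taken from the snippet above): the URL can point at a published BERT SavedModel on TF-Hub and the data config can be built directly; the module URL and input path below are placeholders.

# Illustrative sketch only (assumed usage, not part of the original test).
config = dual_encoder.DualEncoderConfig(
    hub_module_url=(
        "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"),
    train_data=dual_encoder_dataloader.DualEncoderDataConfig(
        input_path="/tmp/train.tfrecord",  # placeholder path
        is_training=True))
task = dual_encoder.DualEncoderTask(config)
model = task.build_model()
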
def labse_train() -> cfg.ExperimentConfig:
  r"""Language-agnostic BERT sentence embedding (LaBSE) training experiment.

  *Note*: this experiment does not use cross-accelerator global softmax, so it
  does not reproduce the exact LaBSE training.
  """
  config = cfg.ExperimentConfig(
      task=dual_encoder.DualEncoderConfig(
          train_data=dual_encoder_dataloader.DualEncoderDataConfig(),
          validation_data=dual_encoder_dataloader.DualEncoderDataConfig(
              is_training=False, drop_remainder=False)),
      trainer=cfg.TrainerConfig(
          optimizer_config=LaBSEOptimizationConfig(
              learning_rate=optimization.LrConfig(
                  type="polynomial",
                  polynomial=PolynomialLr(
                      initial_learning_rate=3e-5,
                      end_learning_rate=0.0)),
              warmup=optimization.WarmupConfig(
                  type="polynomial", polynomial=PolynomialWarmupConfig()))),
      restrictions=[
          "task.train_data.is_training != None",
          "task.validation_data.is_training != None"
      ])
  return config
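
A minimal usage sketch for the experiment factory above (an assumption on my part, not shown in the original snippet): the returned ExperimentConfig carries a DualEncoderConfig under .task, which can be handed to DualEncoderTask exactly as the tests do; the TFRecord path is a placeholder.

# Sketch (assumed usage): turn the experiment config above into a task.
# The input path is a placeholder and must point at real dual-encoder data.
experiment = labse_train()
experiment.task.train_data.input_path = "/tmp/labse_train.tfrecord"
task = dual_encoder.DualEncoderTask(experiment.task)
model = task.build_model()
metrics = task.build_metrics()
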
def test_task(self):
  config = dual_encoder.DualEncoderConfig(
      init_checkpoint=self.get_temp_dir(),
      model=self.get_model_config(),
      train_data=self._train_data_config)
  task = dual_encoder.DualEncoderTask(config)
  model = task.build_model()
  metrics = task.build_metrics()
  dataset = task.build_inputs(config.train_data)

  # Runs one training step and one validation step on the built model.
  iterator = iter(dataset)
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  task.train_step(next(iterator), model, optimizer, metrics=metrics)
  task.validation_step(next(iterator), model, metrics=metrics)

  # Saves a BERT pretrainer checkpoint, then initializes the dual-encoder
  # model from it through the task's init_checkpoint logic.
  pretrain_cfg = bert.PretrainerConfig(encoder=encoders.EncoderConfig(
      bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
  pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
  ckpt = tf.train.Checkpoint(model=pretrain_model,
                             **pretrain_model.checkpoint_items)
  ckpt.save(config.init_checkpoint)
  task.initialize(model)