Example #1
# Imports for this snippet (TF Model Garden layout). LaBSEOptimizationConfig,
# PolynomialLr and PolynomialWarmupConfig are defined alongside this factory
# in the enclosing module and are not shown here.
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.nlp.data import dual_encoder_dataloader
from official.nlp.tasks import dual_encoder


def labse_train() -> cfg.ExperimentConfig:
    r"""Language-agnostic BERT sentence embedding (LaBSE).

    *Note*: this experiment does not use cross-accelerator global softmax, so
    it does not reproduce the exact LaBSE training.
    """
    config = cfg.ExperimentConfig(
        task=dual_encoder.DualEncoderConfig(
            train_data=dual_encoder_dataloader.DualEncoderDataConfig(),
            validation_data=dual_encoder_dataloader.DualEncoderDataConfig(
                is_training=False, drop_remainder=False)),
        trainer=cfg.TrainerConfig(optimizer_config=LaBSEOptimizationConfig(
            learning_rate=optimization.LrConfig(type="polynomial",
                                                polynomial=PolynomialLr(
                                                    initial_learning_rate=3e-5,
                                                    end_learning_rate=0.0)),
            warmup=optimization.WarmupConfig(
                type="polynomial", polynomial=PolynomialWarmupConfig()))),
        restrictions=[
            "task.train_data.is_training != None",
            "task.validation_data.is_training != None"
        ])
    return config
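
A minimal sketch of how such a factory is typically wired up in the Model Garden, assuming the experiment registry in official.core.exp_factory; the experiment name 'labse/train' is an assumption for illustration:

from official.core import exp_factory

# Register the factory under a name, then fetch the config by that name.
exp_factory.register_config_factory('labse/train')(labse_train)
experiment_config = exp_factory.get_exp_config('labse/train')

# Individual fields can then be overridden before launching training.
experiment_config.task.train_data.override({'global_batch_size': 32})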
Example #2
    def test_load_tfds(self, use_preprocessing_hub):
        seq_length = 16
        batch_size = 10
        if use_preprocessing_hub:
            vocab_path = ''
            preprocessing_hub = (
                'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3')
        else:
            vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
            _make_vocab_file(
                ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '#llo', 'world'],
                vocab_path)
            preprocessing_hub = ''

        data_config = dual_encoder_dataloader.DualEncoderDataConfig(
            tfds_name='para_crawl/enmt',
            tfds_split='train',
            seq_length=seq_length,
            vocab_file=vocab_path,
            lower_case=True,
            left_text_fields=('en', ),
            right_text_fields=('mt', ),
            preprocessing_hub_module_url=preprocessing_hub,
            global_batch_size=batch_size)
        dataset = dual_encoder_dataloader.DualEncoderDataLoader(
            data_config).load()
        features = next(iter(dataset))
        expected_keys = [
            'left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
            'right_mask', 'right_type_ids'
        ]
        self.assertCountEqual(expected_keys, features.keys())
        for key in expected_keys:
            self.assertEqual(features[key].shape, (batch_size, seq_length))
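
The helper _make_vocab_file is defined outside this excerpt; a minimal sketch of what it presumably does (one WordPiece token per line, the format BERT-style tokenizers expect):

import tensorflow as tf

def _make_vocab_file(vocab, output_path):
    # One token per line; [PAD]/[UNK]/[CLS]/[SEP] must be present in the
    # vocab for BERT-style tokenization to work.
    with tf.io.gfile.GFile(output_path, 'w') as f:
        f.write('\n'.join(vocab))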
Example #3
    def test_load_dataset(self):
        seq_length = 16
        batch_size = 10
        train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
        vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')

        _create_fake_dataset(train_data_path)
        _make_vocab_file(
            ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '#llo', 'world'],
            vocab_path)

        data_config = dual_encoder_dataloader.DualEncoderDataConfig(
            input_path=train_data_path,
            seq_length=seq_length,
            vocab_file=vocab_path,
            lower_case=True,
            left_text_fields=(_LEFT_FEATURE_NAME, ),
            right_text_fields=(_RIGHT_FEATURE_NAME, ),
            global_batch_size=batch_size)
        dataset = dual_encoder_dataloader.DualEncoderDataLoader(
            data_config).load()
        features = next(iter(dataset))
        expected_keys = [
            'left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
            'right_mask', 'right_type_ids'
        ]
        self.assertCountEqual(expected_keys, features.keys())
        for key in expected_keys:
            self.assertEqual(features[key].shape, (batch_size, seq_length))
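
Likewise, _create_fake_dataset is not shown here; a sketch of one plausible implementation that writes tf.train.Example records with a single raw text string per side (_LEFT_FEATURE_NAME and _RIGHT_FEATURE_NAME stand in for the real constants defined elsewhere in the test file):

import tensorflow as tf

_LEFT_FEATURE_NAME = 'left_input'    # assumed names; the real constants
_RIGHT_FEATURE_NAME = 'right_input'  # live elsewhere in the test module.

def _create_fake_dataset(output_path, num_examples=100):
    # Serialize a few tf.train.Examples with one raw text string per side.
    with tf.io.TFRecordWriter(output_path) as writer:
        for _ in range(num_examples):
            features = {
                _LEFT_FEATURE_NAME: tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[b'hello world'])),
                _RIGHT_FEATURE_NAME: tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[b'world hello'])),
            }
            example = tf.train.Example(
                features=tf.train.Features(feature=features))
            writer.write(example.SerializeToString())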
Example #4
    def setUp(self):
        super().setUp()
        self._train_data_config = dual_encoder_dataloader.DualEncoderDataConfig(
            input_path="dummy", seq_length=32)