Example #1
    def build_inputs(self, params, input_context=None):
        """Returns tf.data.Dataset for pretraining."""
        if params.input_path == 'dummy':

            def dummy_data(_):
                dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
                dummy_lm = tf.zeros((1, params.max_predictions_per_seq),
                                    dtype=tf.int32)
                return dict(input_word_ids=dummy_ids,
                            input_mask=dummy_ids,
                            input_type_ids=dummy_ids,
                            masked_lm_positions=dummy_lm,
                            masked_lm_ids=dummy_lm,
                            masked_lm_weights=tf.cast(dummy_lm,
                                                      dtype=tf.float32),
                            next_sentence_labels=tf.zeros((1, 1),
                                                          dtype=tf.int32))

            dataset = tf.data.Dataset.range(1)
            dataset = dataset.repeat()
            dataset = dataset.map(
                dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            return dataset

        return pretrain_dataloader.BertPretrainDataLoader(params).load(
            input_context)
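
For reference, the dummy branch above can be exercised outside of the task class. The following is a minimal, self-contained sketch (the seq_length and max_predictions_per_seq values are assumptions for illustration, not taken from a real config) that builds the same infinitely repeating dataset of all-zero features and prints one element's shapes.

import tensorflow as tf

SEQ_LENGTH = 128               # assumed value for illustration
MAX_PREDICTIONS_PER_SEQ = 20   # assumed value for illustration

def dummy_data(_):
    # Mirrors the dummy branch of build_inputs: all-zero BERT pretraining features.
    dummy_ids = tf.zeros((1, SEQ_LENGTH), dtype=tf.int32)
    dummy_lm = tf.zeros((1, MAX_PREDICTIONS_PER_SEQ), dtype=tf.int32)
    return dict(
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids,
        masked_lm_positions=dummy_lm,
        masked_lm_ids=dummy_lm,
        masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
        next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))

dataset = tf.data.Dataset.range(1).repeat().map(
    dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)

element = next(iter(dataset))
for name, tensor in element.items():
    print(name, tensor.shape)  # e.g. input_word_ids (1, 128)
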
Example #2
    def test_load_data(self, use_next_sentence_label, use_position_id):
        train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
        seq_length = 128
        max_predictions_per_seq = 20
        _create_fake_dataset(train_data_path,
                             seq_length,
                             max_predictions_per_seq,
                             use_next_sentence_label=use_next_sentence_label,
                             use_position_id=use_position_id)
        data_config = pretrain_dataloader.BertPretrainDataConfig(
            input_path=train_data_path,
            max_predictions_per_seq=max_predictions_per_seq,
            seq_length=seq_length,
            global_batch_size=10,
            is_training=True,
            use_next_sentence_label=use_next_sentence_label,
            use_position_id=use_position_id)

        dataset = pretrain_dataloader.BertPretrainDataLoader(
            data_config).load()
        features = next(iter(dataset))
        self.assertLen(features,
                       6 + int(use_next_sentence_label) + int(use_position_id))
        self.assertIn("input_word_ids", features)
        self.assertIn("input_mask", features)
        self.assertIn("input_type_ids", features)
        self.assertIn("masked_lm_positions", features)
        self.assertIn("masked_lm_ids", features)
        self.assertIn("masked_lm_weights", features)

        self.assertEqual("next_sentence_labels" in features,
                         use_next_sentence_label)
        self.assertEqual("position_ids" in features, use_position_id)
    def test_v2_feature_names(self):
        train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
        seq_length = 128
        max_predictions_per_seq = 20
        _create_fake_bert_dataset(
            train_data_path,
            seq_length,
            max_predictions_per_seq,
            use_next_sentence_label=True,
            use_position_id=False,
            use_v2_feature_names=True)
        data_config = pretrain_dataloader.BertPretrainDataConfig(
            input_path=train_data_path,
            max_predictions_per_seq=max_predictions_per_seq,
            seq_length=seq_length,
            global_batch_size=10,
            is_training=True,
            use_next_sentence_label=True,
            use_position_id=False,
            use_v2_feature_names=True)

        dataset = pretrain_dataloader.BertPretrainDataLoader(data_config).load()
        features = next(iter(dataset))
        self.assertIn("input_word_ids", features)
        self.assertIn("input_mask", features)
        self.assertIn("input_type_ids", features)
        self.assertIn("masked_lm_positions", features)
        self.assertIn("masked_lm_ids", features)
        self.assertIn("masked_lm_weights", features)
    def test_distribution_strategy(self, distribution_strategy):
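        """Compares metrics between dynamically and statically padded data."""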
        max_seq_length = 128
        batch_size = 8
        input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
        _create_fake_dataset(input_path,
                             seq_length=60,
                             num_masked_tokens=20,
                             max_seq_length=max_seq_length,
                             num_examples=batch_size)
        data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
            is_training=False,
            input_path=input_path,
            seq_bucket_lengths=[64, 128],
            global_batch_size=batch_size)
        dataloader = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
            data_config)
        distributed_ds = orbit.utils.make_distributed_dataset(
            distribution_strategy, dataloader.load)
        train_iter = iter(distributed_ds)
        with distribution_strategy.scope():
            config = masked_lm.MaskedLMConfig(
                init_checkpoint=self.get_temp_dir(),
                model=bert.PretrainerConfig(
                    encoders.EncoderConfig(bert=encoders.BertEncoderConfig(
                        vocab_size=30522, num_layers=1)),
                    cls_heads=[
                        bert.ClsHeadConfig(inner_dim=10,
                                           num_classes=2,
                                           name='next_sentence')
                    ]),
                train_data=data_config)
            task = masked_lm.MaskedLMTask(config)
            model = task.build_model()
            metrics = task.build_metrics()

        @tf.function
        def step_fn(features):
            return task.validation_step(features, model, metrics=metrics)

        distributed_outputs = distribution_strategy.run(
            step_fn, args=(next(train_iter), ))
        local_results = tf.nest.map_structure(
            distribution_strategy.experimental_local_results,
            distributed_outputs)
        logging.info('Dynamic padding:  local_results= %s', str(local_results))
        dynamic_metrics = {}
        for metric in metrics:
            dynamic_metrics[metric.name] = metric.result()

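        # Re-run validation with the statically padded dataloader for comparison.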
        data_config = pretrain_dataloader.BertPretrainDataConfig(
            is_training=False,
            input_path=input_path,
            seq_length=max_seq_length,
            max_predictions_per_seq=20,
            global_batch_size=batch_size)
        dataloader = pretrain_dataloader.BertPretrainDataLoader(data_config)
        distributed_ds = orbit.utils.make_distributed_dataset(
            distribution_strategy, dataloader.load)
        train_iter = iter(distributed_ds)
        with distribution_strategy.scope():
            metrics = task.build_metrics()

        @tf.function
        def step_fn_b(features):
            return task.validation_step(features, model, metrics=metrics)

        distributed_outputs = distribution_strategy.run(
            step_fn_b, args=(next(train_iter), ))
        local_results = tf.nest.map_structure(
            distribution_strategy.experimental_local_results,
            distributed_outputs)
        logging.info('Static padding:  local_results= %s', str(local_results))
        static_metrics = {}
        for metric in metrics:
            static_metrics[metric.name] = metric.result()
        for key in static_metrics:
            # We need to investigate the differences on losses.
            if key != 'next_sentence_loss':
                self.assertEqual(dynamic_metrics[key], static_metrics[key])
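
test_distribution_strategy receives its strategy from the test's parameterization, which is not shown on this page. A minimal sketch of driving the dynamic dataloader under a concrete strategy is given below; the single-device strategy, the import path, and the input path are assumptions for illustration, and the TFRecord must already exist (for example, written by a helper like the sketch earlier).

import tensorflow as tf
import orbit
from official.nlp.data import pretrain_dynamic_dataloader

strategy = tf.distribute.OneDeviceStrategy("/cpu:0")  # assumed single-device setup

data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
    is_training=False,
    input_path="/tmp/train.tf_record",  # assumed path for illustration
    seq_bucket_lengths=[64, 128],
    global_batch_size=8)
dataloader = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(data_config)

# make_distributed_dataset wraps dataloader.load so each replica gets its shard.
distributed_ds = orbit.utils.make_distributed_dataset(strategy, dataloader.load)
features = next(iter(distributed_ds))
print({name: t.shape for name, t in features.items()})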