Example 1
    def test_bert2bert_decoding(self):
        """Checks cached and padded decoding agree with plain decoding."""
        seq_length = 10
        self._config.override(
            {
                "beam_size": 3,
                "len_title": seq_length,
                "alpha": 0.6,
            },
            is_strict=False)

        batch_size = 2
        zero_ids = np.zeros((batch_size, seq_length), dtype=np.int32)
        fake_inputs = {
            "input_ids": zero_ids,
            "input_mask": zero_ids,
            "segment_ids": zero_ids,
        }

        def _predict(padded_decode, use_cache, init_checkpoint=None):
            # Builds a fresh model under the given decoding settings and runs
            # prediction. On the first call (no checkpoint yet) it saves one,
            # and every call restores from it so all variants start from
            # identical weights, keeping outputs deterministic.
            self._config.override(
                {
                    "padded_decode": padded_decode,
                    "use_cache": use_cache,
                },
                is_strict=False)
            model = models.create_bert2bert_model(params=self._config)
            ckpt = tf.train.Checkpoint(model=model)
            if init_checkpoint is None:
                init_checkpoint = ckpt.save(
                    os.path.join(self.get_temp_dir(), "ckpt"))
            ckpt.restore(init_checkpoint).assert_existing_objects_matched()
            ids, scores = model(fake_inputs, mode="predict")
            return ids, scores, init_checkpoint

        # Baseline: no cache, no padded decode.
        top_ids, scores, init_checkpoint = _predict(
            padded_decode=False, use_cache=False)

        # Cached decoding must reproduce the baseline ids and scores.
        cached_top_ids, cached_scores, _ = _predict(
            padded_decode=False, use_cache=True,
            init_checkpoint=init_checkpoint)
        self.assertEqual(
            process_decoded_ids(top_ids, self._config.end_token_id),
            process_decoded_ids(cached_top_ids, self._config.end_token_id))
        self.assertAllClose(scores, cached_scores)

        # Padded (fixed-length) decoding must also reproduce the baseline.
        padded_top_ids, padded_scores, _ = _predict(
            padded_decode=True, use_cache=True,
            init_checkpoint=init_checkpoint)
        self.assertEqual(
            process_decoded_ids(top_ids, self._config.end_token_id),
            process_decoded_ids(padded_top_ids, self._config.end_token_id))
        self.assertAllClose(scores, padded_scores)
 def test_bert2bert_eval(self, distribution):
   """Runs predict and eval forward passes under a distribution strategy."""
   seq_length = 10
   # TPU strategies require decoding to a fixed (padded) length.
   tpu_strategies = (tf.distribute.TPUStrategy,
                     tf.distribute.experimental.TPUStrategy)
   self._config.override(
       {
           "beam_size": 3,
           "len_title": seq_length,
           "alpha": 0.6,
           "padded_decode": isinstance(distribution, tpu_strategies),
       },
       is_strict=False)
   # Defines the model inside distribution strategy scope.
   with distribution.scope():
     batch_size = 2
     batches = 4
     all_zero = np.zeros((batch_size * batches, seq_length), dtype=np.int32)
     fake_inputs = {
         "input_ids": all_zero,
         "input_mask": all_zero,
         "segment_ids": all_zero,
     }
     model = models.create_bert2bert_model(params=self._config)
     # Forward path: predict first, then eval, each yielding one result
     # per batch.
     for mode in ("predict", "eval"):
       results = distribution_forward_path(
           distribution, model, fake_inputs, batch_size, mode=mode)
       self.assertLen(results, batches)
Example 3
 def test_model_creation(self):
     """Builds a bert2bert model and runs a training-mode forward pass."""
     zeros = np.zeros((2, 10), dtype=np.int32)
     # All four input features share the same all-zero id tensor.
     inputs = {
         key: zeros
         for key in ("input_ids", "input_mask", "segment_ids", "target_ids")
     }
     bert2bert = models.create_bert2bert_model(params=self._config)
     bert2bert(inputs)
Example 4
 def test_checkpoint_restore(self):
     """Verifies NHNet weight initialization from a bert2bert checkpoint."""
     bert2bert = models.create_bert2bert_model(self._bert2bert_config)
     checkpoint = tf.train.Checkpoint(model=bert2bert)
     saved_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
     nhnet = models.create_nhnet_model(
         params=self._nhnet_config, init_checkpoint=saved_path)
     # Weights restored into the NHNet model must match the source
     # bert2bert model's, pairwise across bert and decoder layers.
     weight_pairs = zip(
         bert2bert.bert_layer.trainable_weights +
         bert2bert.decoder_layer.trainable_weights,
         nhnet.bert_layer.trainable_weights +
         nhnet.decoder_layer.trainable_weights)
     for src_weight, dst_weight in weight_pairs:
         self.assertAllClose(src_weight.numpy(), dst_weight.numpy())
 def test_bert2bert_train_forward(self, distribution):
   """Runs a training forward pass under a distribution strategy."""
   seq_length = 10
   # Defines the model inside distribution strategy scope.
   with distribution.scope():
     batch_size = 2
     batches = 4
     zeros = np.zeros((batch_size * batches, seq_length), dtype=np.int32)
     # All four input features share the same all-zero id tensor.
     fake_inputs = {
         name: zeros
         for name in ("input_ids", "input_mask", "segment_ids", "target_ids")
     }
     model = models.create_bert2bert_model(params=self._config)
     # Forward path yields one result per batch.
     outputs = distribution_forward_path(distribution, model, fake_inputs,
                                         batch_size)
     logging.info("Forward path results: %s", str(outputs))
     self.assertLen(outputs, batches)