예제 #1
0
 def test_nhnet_train_forward(self, distribution):
   """Checks the NHNet forward path under a distribution strategy.

   Builds all-zero int32 inputs covering `batches` batches of `batch_size`
   examples, creates the NHNet model inside `distribution.scope()`, runs
   `distribution_forward_path` with its default mode and asserts that one
   result is returned per batch.
   """
   seq_length = 10
   batch_size, num_docs, batches = 2, 2, 4
   num_examples = batch_size * batches
   # The model is created inside the strategy scope so the variables it
   # builds are managed by `distribution`.
   with distribution.scope():
     # One zero array is shared by ids, mask and segment ids.
     doc_ids = np.zeros((num_examples, num_docs, seq_length), dtype=np.int32)
     fake_inputs = dict(
         input_ids=doc_ids,
         input_mask=doc_ids,
         segment_ids=doc_ids,
         # Targets are twice as long as a single document sequence.
         target_ids=np.zeros((num_examples, seq_length * 2), dtype=np.int32),
     )
     model = models.create_nhnet_model(params=self._nhnet_config)
     results = distribution_forward_path(distribution, model, fake_inputs,
                                         batch_size)
     logging.info("Forward path results: %s", results)
     self.assertLen(results, batches)
예제 #2
0
 def test_nhnet_eval(self, distribution):
   """Checks NHNet in "predict" and then "eval" mode under a strategy.

   Overrides several NHNet config fields (beam_size, len_title, alpha,
   multi_channel_cross_attention, padded_decode) before building the model.
   It then runs `distribution_forward_path` in each mode and asserts that one
   result is returned per batch.
   """
   seq_length = 10
   # padded_decode is turned on only when running under a TPU strategy
   # (either the stable or the experimental class).
   tpu_strategy_types = (tf.distribute.TPUStrategy,
                         tf.distribute.experimental.TPUStrategy)
   config_overrides = {
       "beam_size": 4,
       "len_title": seq_length,
       "alpha": 0.6,
       "multi_channel_cross_attention": True,
       "padded_decode": isinstance(distribution, tpu_strategy_types),
   }
   self._nhnet_config.override(config_overrides, is_strict=False)
   batch_size, num_docs, batches = 2, 2, 4
   num_examples = batch_size * batches
   # Defines the model inside distribution strategy scope.
   with distribution.scope():
     # One zero array is shared by ids, mask and segment ids.
     doc_ids = np.zeros((num_examples, num_docs, seq_length), dtype=np.int32)
     fake_inputs = {
         "input_ids": doc_ids,
         "input_mask": doc_ids,
         "segment_ids": doc_ids,
         "target_ids": np.zeros((num_examples, 5), dtype=np.int32),
     }
     model = models.create_nhnet_model(params=self._nhnet_config)
     # The same model and inputs are exercised in both modes, in this order.
     for mode in ("predict", "eval"):
       results = distribution_forward_path(
           distribution, model, fake_inputs, batch_size, mode=mode)
       self.assertLen(results, batches)
예제 #3
0
 def test_checkpoint_restore(self):
   """Checks that NHNet initializes from a saved Bert2Bert checkpoint.

   Saves a freshly built Bert2Bert model, then builds NHNet with
   `init_checkpoint` pointing at that save. The trainable weights of
   `bert_layer` and `decoder_layer` are compared pairwise between the two
   models.
   """
   source_model = models.create_bert2bert_model(self._bert2bert_config)
   checkpoint = tf.train.Checkpoint(model=source_model)
   save_prefix = os.path.join(self.get_temp_dir(), "ckpt")
   init_checkpoint = checkpoint.save(save_prefix)
   restored_model = models.create_nhnet_model(
       params=self._nhnet_config, init_checkpoint=init_checkpoint)

   def collect_weights(model):
     # bert_layer trainable weights followed by decoder_layer ones.
     return (model.bert_layer.trainable_weights +
             model.decoder_layer.trainable_weights)

   expected_weights = collect_weights(source_model)
   actual_weights = collect_weights(restored_model)
   # NOTE(review): zip() stops at the shorter list. If the two models expose
   # different numbers of trainable weights, only the common prefix is
   # compared.
   for expected, actual in zip(expected_weights, actual_weights):
     self.assertAllClose(expected.numpy(), actual.numpy())