예제 #1
0
 def test_nhnet_train_forward(self, distribution):
     """Run the NHNet forward path in the default mode under `distribution`.

     Builds all-zero fake features, creates the model inside
     `distribution.scope()`, passes everything to
     `distribution_forward_path`, and checks that the number of results
     equals the number of batches.

     Args:
         distribution: Object providing `scope()`; presumably a
             `tf.distribute` strategy supplied by the test
             parameterization -- confirm against the decorator.
     """
     seq_length = 10
     per_step_batch = 2
     docs_per_example = 2
     num_batches = 4
     total_examples = per_step_batch * num_batches

     # The same zero array deliberately backs all three encoder features.
     zero_ids = np.zeros(
         (total_examples, docs_per_example, seq_length), dtype=np.int32)
     zero_targets = np.zeros(
         (total_examples, seq_length * 2), dtype=np.int32)
     features = dict(
         input_ids=zero_ids,
         input_mask=zero_ids,
         segment_ids=zero_ids,
         target_ids=zero_targets,
     )

     # The model has to be created inside the distribution strategy scope.
     with distribution.scope():
         nhnet = models.create_nhnet_model(params=self._nhnet_config)
         outputs = distribution_forward_path(
             distribution, nhnet, features, per_step_batch)
         logging.info("Forward path results: %s", str(outputs))
         self.assertLen(outputs, num_batches)
 def test_nhnet_eval(self, distribution):
   """Run the NHNet "predict" and "eval" forward paths under `distribution`.

   Overrides the decoding-related entries of `self._nhnet_config`, enabling
   `padded_decode` only when `distribution` is a TPU strategy, then checks
   that each mode returns as many results as there are batches.

   Args:
     distribution: Strategy object whose `scope()` wraps model creation and
       the forward calls.
   """
   seq_length = 10
   tpu_strategy_types = (
       tf.distribute.TPUStrategy,
       tf.distribute.experimental.TPUStrategy,
   )
   decode_overrides = {
       "beam_size": 4,
       "len_title": seq_length,
       "alpha": 0.6,
       "multi_channel_cross_attention": True,
       "padded_decode": isinstance(distribution, tpu_strategy_types),
   }
   # Non-strict override: is_strict=False is passed through unchanged.
   self._nhnet_config.override(decode_overrides, is_strict=False)

   per_step_batch = 2
   docs_per_example = 2
   num_batches = 4
   total_examples = per_step_batch * num_batches

   # The same zero array deliberately backs all three encoder features.
   zero_ids = np.zeros(
       (total_examples, docs_per_example, seq_length), dtype=np.int32)
   features = dict(
       input_ids=zero_ids,
       input_mask=zero_ids,
       segment_ids=zero_ids,
       target_ids=np.zeros((total_examples, 5), dtype=np.int32),
   )

   # The model has to be created inside the distribution strategy scope.
   with distribution.scope():
     nhnet = models.create_nhnet_model(params=self._nhnet_config)
     # Same model and inputs for both modes; "predict" runs first.
     for mode in ("predict", "eval"):
       outputs = distribution_forward_path(
           distribution, nhnet, features, per_step_batch, mode=mode)
       self.assertLen(outputs, num_batches)
예제 #3
0
 def test_checkpoint_restore(self):
     """Check that an NHNet model picks up weights from a Bert2Bert checkpoint.

     Saves a freshly created Bert2Bert model to a checkpoint in the test's
     temp directory, creates an NHNet model with that checkpoint as
     `init_checkpoint`, and compares the trainable weights of the
     `bert_layer` and `decoder_layer` of both models pairwise.
     """

     def _bert_and_decoder_weights(model):
         # BERT-layer weights first, then decoder weights, as one list.
         return (model.bert_layer.trainable_weights +
                 model.decoder_layer.trainable_weights)

     reference_model = models.create_bert2bert_model(self._bert2bert_config)
     checkpoint = tf.train.Checkpoint(model=reference_model)
     checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
     saved_path = checkpoint.save(checkpoint_prefix)

     restored_model = models.create_nhnet_model(
         params=self._nhnet_config, init_checkpoint=saved_path)

     expected_weights = _bert_and_decoder_weights(reference_model)
     restored_weights = _bert_and_decoder_weights(restored_model)

     # NOTE(review): zip() stops at the shorter list, so a mismatch in the
     # number of weights would go unnoticed here -- consider asserting equal
     # lengths once confirmed that both models expose the same weights.
     for expected, restored in zip(expected_weights, restored_weights):
         self.assertAllClose(expected.numpy(), restored.numpy())