def test_nhnet_train_forward(self, distribution):
  """Runs NHNet through `distribution_forward_path` with its default mode.

  The model is built inside `distribution.scope()` and fed all-zero int32
  features; the test expects one result per batch (4 batches here).
  """
  seq_length = 10
  # Model variables must be created under the strategy scope.
  with distribution.scope():
    batch_size, num_docs, batches = 2, 2, 4
    num_examples = batch_size * batches
    # A single shared all-zero array stands in for ids, mask and segment ids.
    zeros = np.zeros((num_examples, num_docs, seq_length), dtype=np.int32)
    fake_inputs = dict(
        input_ids=zeros,
        input_mask=zeros,
        segment_ids=zeros,
        # Targets are 2 * seq_length tokens long per example.
        target_ids=np.zeros((num_examples, seq_length * 2), dtype=np.int32),
    )
    model = models.create_nhnet_model(params=self._nhnet_config)
    # No `mode` is passed, so the helper's default mode is used
    # (presumably training, per the test name — confirm in the helper).
    results = distribution_forward_path(
        distribution, model, fake_inputs, batch_size)
    logging.info("Forward path results: %s", str(results))
    self.assertLen(results, batches)
def test_nhnet_eval(self, distribution):
  """Runs NHNet forward passes in "predict" and then "eval" mode.

  First overrides the decoding-related config fields (beam size, title
  length, alpha, multi-channel cross attention, padded decode). It then
  builds the model inside `distribution.scope()` and checks that each mode
  yields one result per batch.
  """
  seq_length = 10
  tpu_strategy_types = (tf.distribute.TPUStrategy,
                        tf.distribute.experimental.TPUStrategy)
  # Padded decoding is switched on only when running under a TPU strategy
  # (presumably because TPU decoding needs static shapes).
  padded_decode = isinstance(distribution, tpu_strategy_types)
  decode_overrides = {
      "beam_size": 4,
      "len_title": seq_length,
      "alpha": 0.6,
      "multi_channel_cross_attention": True,
      "padded_decode": padded_decode,
  }
  self._nhnet_config.override(decode_overrides, is_strict=False)

  # Model variables must be created under the strategy scope.
  with distribution.scope():
    batch_size, num_docs, batches = 2, 2, 4
    num_examples = batch_size * batches
    # A single shared all-zero array stands in for ids, mask and segment ids.
    zeros = np.zeros((num_examples, num_docs, seq_length), dtype=np.int32)
    fake_inputs = dict(
        input_ids=zeros,
        input_mask=zeros,
        segment_ids=zeros,
        target_ids=np.zeros((num_examples, 5), dtype=np.int32),
    )
    model = models.create_nhnet_model(params=self._nhnet_config)
    # The same model and inputs are exercised in both modes, in this order.
    for mode in ("predict", "eval"):
      results = distribution_forward_path(
          distribution, model, fake_inputs, batch_size, mode=mode)
      self.assertLen(results, batches)
def test_checkpoint_restore(self):
  """Checks that NHNet picks up BERT2BERT weights via `init_checkpoint`.

  Saves a freshly built BERT2BERT model to a temporary checkpoint and
  builds NHNet from that checkpoint. It then compares the trainable weights
  of `bert_layer` followed by `decoder_layer`, pairwise and in order.
  """
  bert2bert_model = models.create_bert2bert_model(self._bert2bert_config)
  checkpoint = tf.train.Checkpoint(model=bert2bert_model)
  init_checkpoint = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  nhnet_model = models.create_nhnet_model(
      params=self._nhnet_config, init_checkpoint=init_checkpoint)

  def _encoder_decoder_weights(model):
    # Encoder weights first, then decoder weights, as plain Python lists.
    return (model.bert_layer.trainable_weights +
            model.decoder_layer.trainable_weights)

  # NOTE(review): zip stops at the shorter list, so any extra weights on
  # either side are not compared — confirm this truncation is intended.
  weight_pairs = zip(_encoder_decoder_weights(bert2bert_model),
                     _encoder_decoder_weights(nhnet_model))
  for expected, actual in weight_pairs:
    self.assertAllClose(expected.numpy(), actual.numpy())