def test_bert2bert_decoding(self):
    """Checks cached and padded decoding agree with the uncached baseline.

    Builds three models that differ only in decoding flags
    (``padded_decode`` / ``use_cache``), restores all of them from the same
    checkpoint so their weights are identical, and asserts the decoded ids
    and beam scores match across variants.
    """
    seq_length = 10
    self._config.override(
        {
            "beam_size": 3,
            "len_title": seq_length,
            "alpha": 0.6,
        },
        is_strict=False)
    batch_size = 2
    fake_ids = np.zeros((batch_size, seq_length), dtype=np.int32)
    fake_inputs = {
        "input_ids": fake_ids,
        "input_mask": fake_ids,
        "segment_ids": fake_ids,
    }

    def _predict(padded_decode, use_cache, init_checkpoint=None):
        """Builds a model with the given decoding flags and runs predict.

        If ``init_checkpoint`` is None, saves the freshly-built model's
        variables first so later variants can restore the same weights.
        Returns (top_ids, scores, checkpoint_path).
        """
        self._config.override({
            "padded_decode": padded_decode,
            "use_cache": use_cache,
        }, is_strict=False)
        model = models.create_bert2bert_model(params=self._config)
        ckpt = tf.train.Checkpoint(model=model)
        if init_checkpoint is None:
            # Initializes variables from checkpoint to keep outputs
            # deterministic across the model variants.
            init_checkpoint = ckpt.save(
                os.path.join(self.get_temp_dir(), "ckpt"))
        ckpt.restore(init_checkpoint).assert_existing_objects_matched()
        ids, beam_scores = model(fake_inputs, mode="predict")
        return ids, beam_scores, init_checkpoint

    # Baseline: no padding, no cache; this variant also creates the
    # shared checkpoint.
    top_ids, scores, init_checkpoint = _predict(
        padded_decode=False, use_cache=False)

    # Cached decoding must produce the same ids and scores.
    cached_top_ids, cached_scores, _ = _predict(
        padded_decode=False, use_cache=True, init_checkpoint=init_checkpoint)
    self.assertEqual(
        process_decoded_ids(top_ids, self._config.end_token_id),
        process_decoded_ids(cached_top_ids, self._config.end_token_id))
    self.assertAllClose(scores, cached_scores)

    # Padded decoding (with cache) must also match the baseline.
    padded_top_ids, padded_scores, _ = _predict(
        padded_decode=True, use_cache=True, init_checkpoint=init_checkpoint)
    self.assertEqual(
        process_decoded_ids(top_ids, self._config.end_token_id),
        process_decoded_ids(padded_top_ids, self._config.end_token_id))
    self.assertAllClose(scores, padded_scores)
def test_bert2bert_eval(self, distribution):
    """Runs predict and eval forward passes under a distribution strategy."""
    seq_length = 10
    # Enable padded decoding only for TPU strategies.
    tpu_strategy_types = (tf.distribute.TPUStrategy,
                          tf.distribute.experimental.TPUStrategy)
    padded_decode = isinstance(distribution, tpu_strategy_types)
    self._config.override(
        {
            "beam_size": 3,
            "len_title": seq_length,
            "alpha": 0.6,
            "padded_decode": padded_decode,
        },
        is_strict=False)
    # Defines the model inside distribution strategy scope.
    with distribution.scope():
        # Forward path.
        batch_size = 2
        num_batches = 4
        zeros = np.zeros((batch_size * num_batches, seq_length),
                         dtype=np.int32)
        fake_inputs = {
            key: zeros
            for key in ("input_ids", "input_mask", "segment_ids")
        }
        model = models.create_bert2bert_model(params=self._config)
        for mode in ("predict", "eval"):
            outputs = distribution_forward_path(
                distribution, model, fake_inputs, batch_size, mode=mode)
            self.assertLen(outputs, num_batches)
def test_model_creation(self):
    """Builds a bert2bert model and runs one forward pass on zero inputs."""
    zeros = np.zeros((2, 10), dtype=np.int32)
    feature_keys = ("input_ids", "input_mask", "segment_ids", "target_ids")
    fake_inputs = {key: zeros for key in feature_keys}
    model = models.create_bert2bert_model(params=self._config)
    model(fake_inputs)
def test_checkpoint_restore(self):
    """Checks NHNet initialized from a bert2bert checkpoint copies weights.

    Saves a bert2bert model's checkpoint, creates an NHNet model from it,
    and verifies every trainable weight in the shared BERT and decoder
    layers was transferred.
    """
    bert2bert_model = models.create_bert2bert_model(self._bert2bert_config)
    ckpt = tf.train.Checkpoint(model=bert2bert_model)
    init_checkpoint = ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
    nhnet_model = models.create_nhnet_model(
        params=self._nhnet_config, init_checkpoint=init_checkpoint)
    source_weights = (bert2bert_model.bert_layer.trainable_weights +
                      bert2bert_model.decoder_layer.trainable_weights)
    dest_weights = (nhnet_model.bert_layer.trainable_weights +
                    nhnet_model.decoder_layer.trainable_weights)
    # Guard against zip() silently truncating when the weight lists differ
    # in length, which would let mismatched architectures pass the
    # element-wise comparison below.
    self.assertLen(dest_weights, len(source_weights))
    for source_weight, dest_weight in zip(source_weights, dest_weights):
        self.assertAllClose(source_weight.numpy(), dest_weight.numpy())
def test_bert2bert_train_forward(self, distribution):
    """Runs a training-mode forward pass under a distribution strategy."""
    seq_length = 10
    # Defines the model inside distribution strategy scope.
    with distribution.scope():
        # Forward path.
        batch_size = 2
        num_batches = 4
        zeros = np.zeros((batch_size * num_batches, seq_length),
                         dtype=np.int32)
        fake_inputs = {
            key: zeros
            for key in ("input_ids", "input_mask", "segment_ids",
                        "target_ids")
        }
        model = models.create_bert2bert_model(params=self._config)
        results = distribution_forward_path(distribution, model, fake_inputs,
                                            batch_size)
        logging.info("Forward path results: %s", str(results))
        self.assertLen(results, num_batches)