def test_model_fn_smoke(self, use_tpu, add_critic, encoder_type,
                        encoder_pretraining, train_phase_subset):
  # This tests a train step, running both the input_fn and model_fn code
  # paths, including the backward pass.
  if encoder_pretraining:
    FLAGS.in_domain_pretrain_steps = 10
  else:
    FLAGS.in_domain_pretrain_steps = 0
  FLAGS.pretrain_as_autoencoder = encoder_pretraining == 'autoencoder'

  if encoder_pretraining == 'nsp':
    self.params.update({
        'nsp_pretrain': True,
        'lambda_nsp_pretrain': 1.0,
    })
  elif encoder_pretraining == 'cpp':
    self.params.update({
        'cpp_pretrain_scheme': 'last_two',
        'lambda_cpp_pretrain': 1.0,
    })

  self.params.update({
      'add_critic': add_critic,
      'train_phase_subset': train_phase_subset
  })

  FLAGS.use_tpu = use_tpu
  tf.reset_default_graph()
  # Just test that it doesn't crash.
  sptokens = [pors.BOS, pors.BOP, pors.MASK]
  tk, spid_dict = util.get_tokenizer_with_special(
      os.path.join(self.data_dir, 'wikitext103_32768.subword_vocab'),
      sptokens)
  # Keep the model tiny so the smoke test runs quickly.
  self.params.update({
      'vocab_size': tk.vocab_size,
      'embedding_size': 4,
      'trf_hidden_size': 4,
      'trf_num_heads': 2,
      'max_decode_steps': 2,
      'encoder_type': encoder_type,
  })
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=self.create_tempdir().full_path,
      keep_checkpoint_max=10)
  pors_estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=use_tpu,
      config=run_config,
      model_fn=pors.get_model_fn(spid_dict),
      train_batch_size=4,
      eval_batch_size=4,
      predict_batch_size=4,
      params=self.params)
  files = util.file_list(self.data_dir, 'valid')
  pors_estimator.train(
      input_fn=pors.get_input_fn(self.params, files, True), max_steps=2)
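# The argument list above suggests this method is driven by a parameterized
# test decorator. A minimal sketch of how such cases could be declared with
# absl.testing.parameterized; the case names and argument combinations below
# are illustrative assumptions, not taken from the original file:
#
#   from absl.testing import parameterized
#
#   @parameterized.named_parameters(
#       ('cpu_no_critic', False, False, 'transformer', None, None),
#       ('tpu_critic_nsp', True, True, 'transformer', 'nsp', None),
#   )
#   def test_model_fn_smoke(self, use_tpu, add_critic, encoder_type,
#                           encoder_pretraining, train_phase_subset):
#     ...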
def test_input_fn(self, use_tpu):
  files = util.file_list(self.data_dir, 'valid')
  FLAGS.use_tpu = use_tpu
  input_fn = pors.get_input_fn(self.params, files, False, shuffle=False)
  dataset = input_fn({'batch_size': 2})
  it = dataset.make_one_shot_iterator()
  next_batch = it.get_next()
  with self.session() as ss:
    batch = ss.run(next_batch)
  # batch is a (features, labels) tuple; check the requested batch size.
  self.assertEqual(2, batch[0]['sentences'].shape[0])
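# For reference, the TPUEstimator input_fn contract exercised above: the
# estimator calls input_fn with a params dict containing 'batch_size' and
# expects a tf.data.Dataset back. A minimal self-contained sketch of that
# contract (TF 1.x API; the 'sentences' feature mirrors the assertion above,
# everything else is illustrative):
#
#   def _toy_input_fn(params):
#     batch_size = params['batch_size']
#     features = {'sentences': tf.zeros([8, 4], tf.int32)}
#     dataset = tf.data.Dataset.from_tensor_slices(features)
#     return dataset.repeat().batch(batch_size, drop_remainder=True)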