    def test_model_fn_smoke(self, use_tpu, add_critic, encoder_type,
                            encoder_pretraining, train_phase_subset):
        # This tests a single train step, exercising both the input_fn and
        # model_fn code paths, including the backward pass.
        if encoder_pretraining:
            FLAGS.in_domain_pretrain_steps = 10
        else:
            FLAGS.in_domain_pretrain_steps = 0

        FLAGS.pretrain_as_autoencoder = encoder_pretraining == 'autoencoder'

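        # Configure the scheme-specific pretraining loss and its weight for
        # the 'nsp' and 'cpp' variants.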
        if encoder_pretraining == 'nsp':
            self.params.update({
                'nsp_pretrain': True,
                'lambda_nsp_pretrain': 1.0,
            })
        elif encoder_pretraining == 'cpp':
            self.params.update({
                'cpp_pretrain_scheme': 'last_two',
                'lambda_cpp_pretrain': 1.0,
            })

        self.params.update({
            'add_critic': add_critic,
            'train_phase_subset': train_phase_subset
        })
        FLAGS.use_tpu = use_tpu
        tf.reset_default_graph()
        # Just test it doesn't crash
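        # Build a subword tokenizer extended with the special tokens the
        # model expects (BOS, BOP, MASK).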
        sptokens = [pors.BOS, pors.BOP, pors.MASK]
        tk, spid_dict = util.get_tokenizer_with_special(
            os.path.join(self.data_dir, 'wikitext103_32768.subword_vocab'),
            sptokens)
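        # Tiny model dimensions keep the smoke test fast.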
        self.params.update({
            'vocab_size': tk.vocab_size,
            'embedding_size': 4,
            'trf_hidden_size': 4,
            'trf_num_heads': 2,
            'max_decode_steps': 2,
            'encoder_type': encoder_type,
        })
        run_config = tf.contrib.tpu.RunConfig(
            model_dir=self.create_tempdir().full_path, keep_checkpoint_max=10)

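        # TPUEstimator also runs on CPU/GPU when use_tpu=False, so the same
        # test covers both execution paths.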
        pors_estimator = tf.contrib.tpu.TPUEstimator(
            use_tpu=use_tpu,
            config=run_config,
            model_fn=pors.get_model_fn(spid_dict),
            train_batch_size=4,
            eval_batch_size=4,
            predict_batch_size=4,
            params=self.params)

        files = util.file_list(self.data_dir, 'valid')
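        # Two training steps over the validation files are enough to exercise
        # the input pipeline and a backward pass without slowing the test down.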
        pors_estimator.train(
            input_fn=pors.get_input_fn(self.params, files, True), max_steps=2)
Example #2
  def test_input_fn(self, use_tpu):
    files = util.file_list(self.data_dir, 'valid')
    FLAGS.use_tpu = use_tpu
    input_fn = pors.get_input_fn(self.params, files, False, shuffle=False)

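    # Pull one batch through a one-shot iterator and check the batch dimension.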
    dataset = input_fn({'batch_size': 2})
    it = dataset.make_one_shot_iterator()
    next_batch = it.get_next()
    with self.session() as ss:
      batch = ss.run(next_batch)
      self.assertEqual(2, batch[0]['sentences'].shape[0])