def test_model_fn_smoke(self, use_tpu, add_critic, encoder_type,
                        encoder_pretraining, train_phase_subset):
  """Smoke-tests a short training run through input_fn and model_fn.

  Runs a couple of train steps (including the backward pass) on a tiny model
  configuration. There are no assertions: the test passes as long as
  nothing raises.
  """
  # Flags that depend on the requested encoder pretraining mode.
  FLAGS.in_domain_pretrain_steps = 10 if encoder_pretraining else 0
  FLAGS.pretrain_as_autoencoder = encoder_pretraining == 'autoencoder'

  # Extra model params needed only by the 'nsp' and 'cpp' pretraining modes.
  if encoder_pretraining == 'nsp':
    pretrain_params = {
        'nsp_pretrain': True,
        'lambda_nsp_pretrain': 1.0,
    }
  elif encoder_pretraining == 'cpp':
    pretrain_params = {
        'cpp_pretrain_scheme': 'last_two',
        'lambda_cpp_pretrain': 1.0,
    }
  else:
    pretrain_params = None
  if pretrain_params is not None:
    self.params.update(pretrain_params)

  self.params.update({
      'add_critic': add_critic,
      'train_phase_subset': train_phase_subset
  })
  FLAGS.use_tpu = use_tpu
  tf.reset_default_graph()

  # Tokenizer extended with the special tokens the model function needs.
  special_tokens = [pors.BOS, pors.BOP, pors.MASK]
  vocab_path = os.path.join(self.data_dir,
                            'wikitext103_32768.subword_vocab')
  tokenizer, special_ids = util.get_tokenizer_with_special(
      vocab_path, special_tokens)

  # Deliberately tiny model dimensions.
  tiny_model_params = {
      'vocab_size': tokenizer.vocab_size,
      'embedding_size': 4,
      'trf_hidden_size': 4,
      'trf_num_heads': 2,
      'max_decode_steps': 2,
      'encoder_type': encoder_type,
  }
  self.params.update(tiny_model_params)

  config = tf.contrib.tpu.RunConfig(
      model_dir=self.create_tempdir().full_path,
      keep_checkpoint_max=10)
  model_fn = pors.get_model_fn(special_ids)
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=use_tpu,
      config=config,
      model_fn=model_fn,
      train_batch_size=4,
      eval_batch_size=4,
      predict_batch_size=4,
      params=self.params)

  valid_files = util.file_list(self.data_dir, 'valid')
  train_input_fn = pors.get_input_fn(self.params, valid_files, True)
  estimator.train(input_fn=train_input_fn, max_steps=2)
def test_get_tokenizer_with_special(self):
  """Checks that special tokens get ids right after the original vocab."""
  special_tokens = ['<SPECIAL1>', '<SPECIAL2>']
  base_tokenizer = util.get_tokenizer(self.vocab)
  extended_tokenizer, special_ids = util.get_tokenizer_with_special(
      self.vocab, special_tokens)

  base_size = base_tokenizer.vocab_size
  first_id = base_size
  second_id = base_size + 1

  # The vocabulary grows by exactly the two added tokens.
  self.assertEqual(base_size + 2, extended_tokenizer.vocab_size)
  # The new ids decode back to the special tokens (with a trailing '_').
  self.assertEqual(['<SPECIAL1>_', '<SPECIAL2>_'],
                   extended_tokenizer.decode_list([first_id, second_id]))
  # The returned mapping points each special token at its new id.
  self.assertEqual(first_id, special_ids['<SPECIAL1>'])
  self.assertEqual(second_id, special_ids['<SPECIAL2>'])
def main(argv):
  """Scores human and extractive baselines and writes oracle sentences.

  Reads the examples for `FLAGS.eval_subset` from `FLAGS.data_dir`, aggregates
  rouge scores for the human-summary and extractive baselines, writes the
  oracle sentence of every example to a text file in `FLAGS.output_dir`, and
  prints the aggregated scores.

  Args:
    argv: Command-line arguments; anything beyond the program name is
      rejected.

  Raises:
    app.UsageError: If extra positional command-line arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # makedirs (unlike mkdir) also creates missing parent directories and
  # succeeds when the directory already exists.
  tf.io.gfile.makedirs(FLAGS.output_dir)

  subset = six.ensure_str(FLAGS.eval_subset)
  data_file = os.path.join(FLAGS.data_dir,
                           'rocstories_gt.' + subset + '.tfrecord')
  seq_ex_list = util.get_seq_exs(data_file)
  print('Input data %s' % data_file)

  # Number of per-position extractive baselines that are tracked and printed.
  num_extract_positions = 5

  # Human summary baselines.
  # We have 3 human summaries for each example, and
  # 2 human performance variants:
  # 1. 'a': average pairwise rouge between two summaries
  # 2. 'm': maximum pairwise rouge between any two summaries
  agg_human = {}
  nwords_human = {}
  for h in ('a', 'm'):
    agg_human[h] = scoring.BootstrapAggregator()
    nwords_human[h] = []

  # Extractive baselines
  # 1. '0','1','2','3','4': rouge between the sentence at that (0-based)
  #    position and the human summaries
  # 2. 'o': for each example, choose sentence with maximum average rouge
  agg_extract = {}
  nwords_extract = {}
  for e in [str(x) for x in range(num_extract_positions)] + ['o']:
    agg_extract[e] = scoring.BootstrapAggregator()
    nwords_extract[e] = []

  # Maps each example's first sentence to its (story, oracle sentence) pair.
  sent2oracle = {}
  for ex in seq_ex_list:
    summ_list = [x.decode('utf-8') for x in p2s_eval.get_summaries(ex)]

    # human eval
    score, nwords = human_ave(summ_list)
    agg_human['a'].add_scores(score)
    nwords_human['a'].append(nwords)

    score, nwords = human_max(summ_list)
    agg_human['m'].add_scores(score)
    nwords_human['m'].append(nwords)

    # extractive eval
    # NOTE(review): an example with more than `num_extract_positions`
    # extracts would raise KeyError here -- confirm the data guarantees this.
    extract_list = [x.decode('utf-8') for x in get_extracts(ex)]
    for e_id, e in enumerate(extract_list):
      score, nwords = extract_ave(e, summ_list)
      agg_extract[str(e_id)].add_scores(score)
      nwords_extract[str(e_id)].append(nwords)

    score, nwords, e_o = extract_oracle(extract_list, summ_list)
    agg_extract['o'].add_scores(score)
    nwords_extract['o'].append(nwords)

    # save story and oracle sentence for future use
    first = p2s_eval.get_first_sentence(ex)
    if first in sent2oracle:
      logging.fatal('duplicate first sentence: %s', str(first))
    sent2oracle[first] = (' '.join(extract_list), e_o)  # (story, oracle)

  # Write the oracle sentence of each example to disk, one per line, ordered
  # by the decoded first sentence of the example.
  tk, _ = util.get_tokenizer_with_special(FLAGS.vocab_file, [])

  def detok(s):
    return tk.decode(util.strip_after_eos(s))

  keys_sorted = sorted(sent2oracle, key=detok)

  out_file = os.path.join(
      FLAGS.output_dir,
      'rocstories_gt.' + subset + '.firstsent2oracle.txt')
  # Use the tf.io.gfile API consistently; the legacy tf.gfile.Open alias is
  # not available in the TF2 namespace.
  with tf.io.gfile.GFile(out_file, 'w') as f:
    for k in keys_sorted:
      f.write('%s\n' % (sent2oracle[k][1]))

  # print out rouge scores for human performance
  print_agg_score('human average', agg_human['a'], nwords_human['a'])
  print_agg_score('human max', agg_human['m'], nwords_human['m'])

  # print out rouge scores for the extractive baselines
  for e_id in range(num_extract_positions):
    print_agg_score('extractive baseline{}'.format(e_id),
                    agg_extract[str(e_id)], nwords_extract[str(e_id)])
  print_agg_score('extractive oracle', agg_extract['o'], nwords_extract['o'])