Ejemplo n.º 1
0
    def test_model_fn_smoke(self, use_tpu, add_critic, encoder_type,
                            encoder_pretraining, train_phase_subset):
        """Smoke-tests a short PORS training run.

        This tests a train step end to end, running both the input_fn and
        model_fn code paths, including the backward pass. It only checks that
        nothing raises.

        Args:
          use_tpu: written to FLAGS.use_tpu and passed to the TPUEstimator.
          add_critic: stored under params['add_critic'].
          encoder_type: stored under params['encoder_type'].
          encoder_pretraining: falsy disables in-domain pretraining. The
            values 'autoencoder', 'nsp' and 'cpp' additionally select
            mode-specific flags/params.
          train_phase_subset: stored under params['train_phase_subset'].
        """
        # Any truthy pretraining mode turns on a short in-domain pretrain phase.
        FLAGS.in_domain_pretrain_steps = 10 if encoder_pretraining else 0
        FLAGS.pretrain_as_autoencoder = encoder_pretraining == 'autoencoder'

        # Mode-specific params; only 'nsp' and 'cpp' contribute any.
        mode_params = None
        if encoder_pretraining == 'nsp':
            mode_params = {
                'nsp_pretrain': True,
                'lambda_nsp_pretrain': 1.0,
            }
        elif encoder_pretraining == 'cpp':
            mode_params = {
                'cpp_pretrain_scheme': 'last_two',
                'lambda_cpp_pretrain': 1.0,
            }
        if mode_params is not None:
            self.params.update(mode_params)

        self.params.update({
            'add_critic': add_critic,
            'train_phase_subset': train_phase_subset,
        })
        FLAGS.use_tpu = use_tpu
        tf.reset_default_graph()

        # Subword vocabulary extended with the PORS special tokens.
        vocab_path = os.path.join(self.data_dir,
                                  'wikitext103_32768.subword_vocab')
        tokenizer, special_ids = util.get_tokenizer_with_special(
            vocab_path, [pors.BOS, pors.BOP, pors.MASK])

        # Tiny model dimensions keep this smoke test cheap.
        self.params.update({
            'vocab_size': tokenizer.vocab_size,
            'embedding_size': 4,
            'trf_hidden_size': 4,
            'trf_num_heads': 2,
            'max_decode_steps': 2,
            'encoder_type': encoder_type,
        })

        config = tf.contrib.tpu.RunConfig(
            model_dir=self.create_tempdir().full_path,
            keep_checkpoint_max=10)
        estimator = tf.contrib.tpu.TPUEstimator(
            use_tpu=use_tpu,
            config=config,
            model_fn=pors.get_model_fn(special_ids),
            train_batch_size=4,
            eval_batch_size=4,
            predict_batch_size=4,
            params=self.params)

        valid_files = util.file_list(self.data_dir, 'valid')
        estimator.train(
            input_fn=pors.get_input_fn(self.params, valid_files, True),
            max_steps=2)
Ejemplo n.º 2
0
 def test_get_tokenizer_with_special(self):
     """Extra tokens receive consecutive ids right after the base vocabulary."""
     base_tokenizer = util.get_tokenizer(self.vocab)
     added = ['<SPECIAL1>', '<SPECIAL2>']
     extended_tokenizer, special_ids = util.get_tokenizer_with_special(
         self.vocab, added)
     first_new_id = base_tokenizer.vocab_size
     new_ids = [first_new_id, first_new_id + 1]

     # Two tokens were added, so the vocabulary grows by exactly two.
     self.assertEqual(first_new_id + 2, extended_tokenizer.vocab_size)
     # Decoding the new ids gives back the tokens, each followed by '_'.
     self.assertEqual(['<SPECIAL1>_', '<SPECIAL2>_'],
                      extended_tokenizer.decode_list(new_ids))
     # The returned token -> id mapping agrees with those ids.
     for token, expected_id in zip(added, new_ids):
         self.assertEqual(expected_id, special_ids[token])
def _new_aggregators(keys):
    """Creates a fresh ROUGE aggregator and an empty word-count list per key.

    Args:
      keys: list of labels.

    Returns:
      A pair ``(aggregators, nwords)`` of dicts keyed by ``keys``. Each
      aggregator is a new ``scoring.BootstrapAggregator``. Each word-count
      entry is a new empty list.
    """
    aggregators = {k: scoring.BootstrapAggregator() for k in keys}
    nwords = {k: [] for k in keys}
    return aggregators, nwords


def main(argv):
    """Computes ROUGE baselines on ROCStories and writes the extractive oracles.

    Reads ``rocstories_gt.<eval_subset>.tfrecord`` from ``FLAGS.data_dir`` and
    scores, for every example:
      * human baselines: average ('a') and maximum ('m') pairwise ROUGE
        between the example's human summaries;
      * extractive baselines: each story sentence against the human summaries
        (keys '0'..'4', the 0-based sentence index), plus an oracle ('o') that
        picks, per example, the sentence with maximum average ROUGE.
    The oracle sentence of each example is written, one per line, to
    ``<output_dir>/rocstories_gt.<eval_subset>.firstsent2oracle.txt``. Lines
    are ordered by the detokenized first sentence of the story. Aggregate
    scores are printed.

    Args:
      argv: positional command-line arguments. Only the program name is
        allowed.

    Raises:
      app.UsageError: if extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    # makedirs (unlike mkdir) succeeds if the directory already exists and
    # also creates missing parents, so the script can be re-run.
    tf.io.gfile.makedirs(FLAGS.output_dir)

    file_prefix = 'rocstories_gt.' + six.ensure_str(FLAGS.eval_subset)
    data_file = os.path.join(FLAGS.data_dir, file_prefix + '.tfrecord')
    seq_ex_list = util.get_seq_exs(data_file)
    print('Input data %s' % data_file)

    # Human summary baselines (3 human summaries per example):
    #   'a': average pairwise rouge between two summaries
    #   'm': maximum pairwise rouge between any two summaries
    agg_human, nwords_human = _new_aggregators(['a', 'm'])

    # Extractive baselines:
    #   '0'..'4': rouge between the i-th (0-based) sentence and human summaries
    #   'o': for each example, the sentence with maximum average rouge
    # NOTE(review): assumes at most num_sentences extracts per example; a
    # longer story would raise KeyError below.
    num_sentences = 5
    sentence_keys = [str(i) for i in range(num_sentences)]
    agg_extract, nwords_extract = _new_aggregators(sentence_keys + ['o'])

    sent2oracle = {}
    for ex in seq_ex_list:
        summ_list = [x.decode('utf-8') for x in p2s_eval.get_summaries(ex)]

        # Human performance.
        score, nwords = human_ave(summ_list)
        agg_human['a'].add_scores(score)
        nwords_human['a'].append(nwords)

        score, nwords = human_max(summ_list)
        agg_human['m'].add_scores(score)
        nwords_human['m'].append(nwords)

        # Extractive baselines.
        extract_list = [x.decode('utf-8') for x in get_extracts(ex)]
        for e_id, sentence in enumerate(extract_list):
            score, nwords = extract_ave(sentence, summ_list)
            agg_extract[str(e_id)].add_scores(score)
            nwords_extract[str(e_id)].append(nwords)

        score, nwords, oracle = extract_oracle(extract_list, summ_list)
        agg_extract['o'].add_scores(score)
        nwords_extract['o'].append(nwords)

        # Save the story and its oracle sentence, keyed by the first sentence.
        first = p2s_eval.get_first_sentence(ex)
        if first in sent2oracle:
            # NOTE(review): absl's logging.fatal aborts here; stdlib
            # logging.fatal only logs, and the entry would be overwritten.
            logging.fatal('duplicate first sentence: %s', str(first))
        sent2oracle[first] = (' '.join(extract_list), oracle)  # (story, oracle)

    # Write each example's oracle to disk, ordered by detokenized first
    # sentence.
    tk, _ = util.get_tokenizer_with_special(FLAGS.vocab_file, [])

    def detok(s):
        return tk.decode(util.strip_after_eos(s))

    keys_sorted = sorted(sent2oracle, key=detok)

    out_file = os.path.join(FLAGS.output_dir,
                            file_prefix + '.firstsent2oracle.txt')
    with tf.io.gfile.GFile(out_file, 'w') as f:
        for k in keys_sorted:
            # A one-element tuple keeps '%' from unpacking a tuple-valued oracle.
            f.write('%s\n' % (sent2oracle[k][1],))

    # Print rouge scores for the human and extractive baselines.
    print_agg_score('human average', agg_human['a'], nwords_human['a'])
    print_agg_score('human max', agg_human['m'], nwords_human['m'])
    for e_id in range(num_sentences):
        print_agg_score('extractive baseline{}'.format(e_id),
                        agg_extract[str(e_id)], nwords_extract[str(e_id)])
    print_agg_score('extractive oracle', agg_extract['o'], nwords_extract['o'])