Example #1
def main(unused_argv):
    if len(unused_argv) != 1:  # raise if flags were entered incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.singles_and_pairs == 'both':
        FLAGS.exp_name = FLAGS.exp_name + '_both'
        exp_name = _exp_name + '_both'
        dataset_articles = _dataset_articles
    else:
        FLAGS.exp_name = FLAGS.exp_name + '_singles'
        exp_name = _exp_name + '_singles'
        dataset_articles = _dataset_articles + '_singles'
    my_log_dir = os.path.join(log_dir, FLAGS.ssi_exp_name)

    print('Running statistics on %s' % FLAGS.exp_name)

    if FLAGS.dataset_name != "":
        FLAGS.data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name,
                                       FLAGS.dataset_split + '*')
    dataset_dir = os.path.join(FLAGS.data_root, FLAGS.dataset_name)
    if not os.path.exists(dataset_dir) or len(os.listdir(dataset_dir)) == 0:
        print('No TF example data found at %s so creating it from raw data.' %
              dataset_dir)
        convert_data.process_dataset(FLAGS.dataset_name)

    logging.set_verbosity(logging.INFO)  # choose what level of logging you want
    logging.info('Starting seq2seq_attention in %s mode...', FLAGS.mode)

    # Change log_root to FLAGS.log_root/FLAGS.exp_name
    FLAGS.exp_name = FLAGS.exp_name if FLAGS.exp_name != '' else FLAGS.dataset_name
    FLAGS.actual_log_root = FLAGS.log_root
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # If in decode mode, set batch_size = beam_size
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception(
            "The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_steps', 'max_enc_steps', 'coverage',
        'cov_loss_wt', 'pointer_gen', 'lambdamart_input'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val.value  # add it to the dict
    hps = namedtuple("HParams", list(hps_dict.keys()))(**hps_dict)

    tf.set_random_seed(113)  # fix the graph-level seed for reproducibility

    # The model is configured with max_dec_steps=1 because we only ever run one
    # step of the decoder at a time (to do beam search). Note that the batcher
    # is initialized with max_dec_steps equal to e.g. 100 because the batches
    # need to contain the full summaries.
    decode_model_hps = hps._replace(max_dec_steps=1)

    if len(unused_argv) != 1:  # raise if flags were entered incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(data_dir, dataset_articles)
    source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))

    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:  # pickles must be read in binary mode
        ssi_list = pickle.load(f)

    total = len(source_files) * 1000 if ('cnn' in dataset_articles or
                                         'newsroom' in dataset_articles) else len(source_files)
    example_generator = data.example_generator(
        source_dir + '/' + dataset_split + '*', True, False,
        should_check_valid=False)
    # batcher = Batcher(None, vocab, hps, single_pass=FLAGS.single_pass)
    model = SummarizationModel(decode_model_hps, vocab)
    decoder = BeamSearchDecoder(model, None, vocab)
    decoder.decode_iteratively(example_generator, total, names_to_types,
                               ssi_list, hps)

    a = 0  # no-op; likely kept as a breakpoint anchor for debugging
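
The 'total' computation above must test each substring explicitly: an expression like 'cnn' or 'newsroom' in dataset_articles parses as ('cnn') or ('newsroom' in dataset_articles), which is always truthy because a non-empty string literal is truthy. A minimal, self-contained sketch of the pitfall (plain Python; the value of dataset_articles is illustrative):

    # dataset_articles here is illustrative; it contains neither substring.
    dataset_articles = 'websplit'

    buggy = 'cnn' or 'newsroom' in dataset_articles
    # Parsed as ('cnn') or ('newsroom' in dataset_articles); the non-empty
    # string 'cnn' is truthy, so the result is 'cnn' regardless of the input.
    print(bool(buggy))  # True

    fixed = 'cnn' in dataset_articles or 'newsroom' in dataset_articles
    print(fixed)  # False
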
Example #2
def main(unused_argv):
    if len(unused_argv) != 1:  # raise if flags were entered incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    extractor = 'bert' if FLAGS.use_bert else 'lambdamart'
    pretrained_dataset = FLAGS.dataset_name
    if FLAGS.dataset_name == 'duc_2004':
        pretrained_dataset = 'cnn_dm'
    if FLAGS.singles_and_pairs == 'both':
        FLAGS.exp_name = FLAGS.dataset_name + '_' + FLAGS.exp_name + extractor + '_both'
        FLAGS.pretrained_path = os.path.join(FLAGS.log_root,
                                             pretrained_dataset + '_both')
        dataset_articles = FLAGS.dataset_name
    else:
        FLAGS.exp_name = FLAGS.dataset_name + '_' + FLAGS.exp_name + extractor + '_singles'
        FLAGS.pretrained_path = os.path.join(FLAGS.log_root,
                                             pretrained_dataset + '_singles')
        dataset_articles = FLAGS.dataset_name + '_singles'

    if FLAGS.upper_bound:
        FLAGS.exp_name = FLAGS.exp_name + '_upperbound'
        ssi_list = None  # this is if we are doing the upper bound evaluation (ssi_list comes straight from the groundtruth)
    else:
        my_log_dir = os.path.join(
            log_dir, '%s_%s_%s' %
            (FLAGS.dataset_name, extractor, FLAGS.singles_and_pairs))
        with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:
            ssi_list = pickle.load(f)

    print('Running statistics on %s' % FLAGS.exp_name)

    if FLAGS.dataset_name != "":
        FLAGS.data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name,
                                       FLAGS.dataset_split + '*')
    dataset_dir = os.path.join(FLAGS.data_root, FLAGS.dataset_name)
    if not os.path.exists(dataset_dir) or len(os.listdir(dataset_dir)) == 0:
        raise Exception('No TF example data found at %s.' % dataset_dir)

    logging.set_verbosity(logging.INFO)  # choose what level of logging you want
    logging.info('Starting seq2seq_attention in %s mode...', FLAGS.mode)

    # Change log_root to FLAGS.log_root/FLAGS.exp_name
    FLAGS.exp_name = FLAGS.exp_name if FLAGS.exp_name != '' else FLAGS.dataset_name
    FLAGS.actual_log_root = FLAGS.log_root
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)

    print(util.bcolors.OKGREEN + "Experiment path: " + FLAGS.log_root +
          util.bcolors.ENDC)

    if FLAGS.dataset_name == 'duc_2004':
        vocab = Vocab(FLAGS.vocab_path + '_' + 'cnn_dm',
                      FLAGS.vocab_size)  # create a vocabulary
    else:
        vocab = Vocab(FLAGS.vocab_path + '_' + FLAGS.dataset_name,
                      FLAGS.vocab_size)  # create a vocabulary

    # If in decode mode, set batch_size = beam_size
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception(
            "The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        item for item in list(FLAGS.flag_values_dict().keys()) if item != '?'
    ]
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val.value  # add it to the dict
    hps = namedtuple("HParams", list(hps_dict.keys()))(**hps_dict)

    tf.set_random_seed(113)  # fix the graph-level seed for reproducibility

    # The model is configured with max_dec_steps=1 because we only ever run one
    # step of the decoder at a time (to do beam search). Note that the batcher
    # is initialized with max_dec_steps equal to e.g. 100 because the batches
    # need to contain the full summaries.
    decode_model_hps = hps._replace(max_dec_steps=1)

    if len(unused_argv) != 1:  # raise if flags were entered incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(FLAGS.data_root, dataset_articles)
    source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))

    total = len(source_files) * 1000 if ('cnn' in dataset_articles or
                                         'xsum' in dataset_articles) else len(source_files)
    example_generator = data.example_generator(
        source_dir + '/' + dataset_split + '*', True, False,
        should_check_valid=False)
    # batcher = Batcher(None, vocab, hps, single_pass=FLAGS.single_pass)
    model = SummarizationModel(decode_model_hps, vocab)
    decoder = BeamSearchDecoder(model, None, vocab)
    decoder.decode_iteratively(example_generator, total, names_to_types,
                               ssi_list, hps)
Example #3
def main(unused_argv):
    if len(unused_argv) != 1:  # raise if flags were entered incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    extractor = 'bert' if FLAGS.use_bert else 'lambdamart'
    if FLAGS.cnn_dm_pg:
        pretrained_dataset = 'cnn_dm'
    elif FLAGS.websplit:
        pretrained_dataset = 'websplit'
    else:
        pretrained_dataset = FLAGS.dataset_name
    if FLAGS.dataset_name == 'duc_2004':
        pretrained_dataset = 'cnn_dm'
    if FLAGS.singles_and_pairs == 'both':
        FLAGS.exp_name = FLAGS.dataset_name + '_' + FLAGS.exp_name + extractor + '_both'
        FLAGS.pretrained_path = os.path.join(FLAGS.log_root, pretrained_dataset + '_both')
        dataset_articles = FLAGS.dataset_name
    else:
        FLAGS.exp_name = FLAGS.dataset_name + '_' + FLAGS.exp_name + extractor + '_singles'
        FLAGS.pretrained_path = os.path.join(FLAGS.log_root, pretrained_dataset + '_singles')
        dataset_articles = FLAGS.dataset_name + '_singles'
    if FLAGS.word_imp_reg:
        FLAGS.pretrained_path += '_imp' + str(FLAGS.imp_loss_wt)
        FLAGS.exp_name += '_imp' + str(FLAGS.imp_loss_wt)
        if FLAGS.imp_loss_oneminus:
            FLAGS.pretrained_path += '_oneminus'
            FLAGS.exp_name += '_oneminus'
    if FLAGS.sep:
        FLAGS.pretrained_path += '_sep'
        FLAGS.exp_name += '_sep'
    if FLAGS.tag_tokens:
        FLAGS.pretrained_path += '_tag'
        FLAGS.exp_name += '_tag' + str(FLAGS.tag_loss_wt)

    bert_suffix = ''
    # if FLAGS.use_bert:
    #     if FLAGS.sentemb:
    #         FLAGS.exp_name += '_sentemb'
    #         bert_suffix += '_sentemb'
    #     if FLAGS.artemb:
    #         FLAGS.exp_name += '_artemb'
    #         bert_suffix += '_artemb'
    #     if FLAGS.plushidden:
    #         FLAGS.exp_name += '_plushidden'
    #         bert_suffix += '_plushidden'
    if FLAGS.tag_tokens:
        bert_suffix += '_tag' + str(FLAGS.tag_loss_wt)
    else:
        bert_suffix += '_tag' + '0.0'
    if FLAGS.upper_bound:
        FLAGS.exp_name = FLAGS.exp_name + '_upperbound'
        ssi_list = None     # this is if we are doing the upper bound evaluation (ssi_list comes straight from the groundtruth)
    else:
        my_log_dir = os.path.join(log_dir, '%s_%s_%s%s' % (FLAGS.dataset_name, extractor, FLAGS.singles_and_pairs, bert_suffix))
        print(util.bcolors.OKGREEN + "BERT path: " + my_log_dir + util.bcolors.ENDC)
        with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:
            ssi_list = pickle.load(f)
        FLAGS.ssi_data_path = my_log_dir
    if FLAGS.cnn_dm_pg:
        FLAGS.exp_name = FLAGS.exp_name + '_cnntrained'
    if FLAGS.websplit:
        FLAGS.exp_name = FLAGS.exp_name + '_websplittrained'
    if FLAGS.first_intact:
        FLAGS.exp_name = FLAGS.exp_name + '_firstintact'

    print('Running statistics on %s' % FLAGS.exp_name)

    if FLAGS.dataset_name != "":
        FLAGS.data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name, FLAGS.dataset_split + '*')
    dataset_dir = os.path.join(FLAGS.data_root, FLAGS.dataset_name)
    if not os.path.exists(dataset_dir) or len(os.listdir(dataset_dir)) == 0:
        print('No TF example data found at %s so creating it from raw data.' %
              dataset_dir)
        convert_data.process_dataset(FLAGS.dataset_name)

    logging.set_verbosity(logging.INFO)  # choose what level of logging you want
    logging.info('Starting seq2seq_attention in %s mode...', FLAGS.mode)

    # Change log_root to FLAGS.log_root/FLAGS.exp_name
    FLAGS.exp_name = FLAGS.exp_name if FLAGS.exp_name != '' else FLAGS.dataset_name
    FLAGS.actual_log_root = FLAGS.log_root
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)

    print(util.bcolors.OKGREEN + "Experiment path: " + FLAGS.log_root + util.bcolors.ENDC)


    if FLAGS.dataset_name == 'duc_2004':
        vocab = Vocab(FLAGS.vocab_path + '_' + 'cnn_dm', FLAGS.vocab_size, add_sep=FLAGS.sep) # create a vocabulary
    else:
        vocab_datasets = [os.path.basename(file_path).split('vocab_')[1] for file_path in glob.glob(FLAGS.vocab_path + '_*')]
        original_dataset_name = [file_name for file_name in vocab_datasets if file_name in FLAGS.dataset_name]
        if len(original_dataset_name) > 1:
            raise Exception('Too many choices for vocab file')
        if len(original_dataset_name) < 1:
            raise Exception('No vocab file for dataset created. Run make_vocab.py --dataset_name=<my original dataset name>')
        original_dataset_name = original_dataset_name[0]
        FLAGS.original_dataset_name = original_dataset_name
        vocab = Vocab(FLAGS.vocab_path + '_' + original_dataset_name, FLAGS.vocab_size, add_sep=FLAGS.sep) # create a vocabulary

    # If in decode mode, set batch_size = beam_size
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception("The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    # hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
    #                'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps',
    #                'max_enc_steps', 'coverage', 'cov_loss_wt', 'pointer_gen', 'lambdamart_input', 'pg_mmr', 'singles_and_pairs', 'skip_with_less_than_3',
    #                'ssi_data_path', 'word_imp_reg', 'imp_loss_wt']
    hparam_list = [item for item in list(FLAGS.flag_values_dict().keys()) if item != '?']
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val.value  # add it to the dict
    hps = namedtuple("HParams", list(hps_dict.keys()))(**hps_dict)

    tf.set_random_seed(113)  # fix the graph-level seed for reproducibility

    # The model is configured with max_dec_steps=1 because we only ever run one
    # step of the decoder at a time (to do beam search). Note that the batcher
    # is initialized with max_dec_steps equal to e.g. 100 because the batches
    # need to contain the full summaries.
    decode_model_hps = hps._replace(max_dec_steps=1)

    if len(unused_argv) != 1:  # raise if flags were entered incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)
    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(FLAGS.data_root, dataset_articles)
    source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))

    total = len(source_files) * 1000 if 'cnn' in dataset_articles or 'xsum' in dataset_articles else len(source_files)
    example_generator = data.example_generator(source_dir + '/' + dataset_split + '*', True, False,
                                               should_check_valid=False)
    # batcher = Batcher(None, vocab, hps, single_pass=FLAGS.single_pass)
    model = SummarizationModel(decode_model_hps, vocab)
    decoder = BeamSearchDecoder(model, None, vocab)
    decoder.decode_iteratively(example_generator, total, names_to_types, ssi_list, hps)

    # num_outside = []
    # for example_idx, example in enumerate(tqdm(example_generator, total=total)):
    #     raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs = util.unpack_tf_example(
    #         example, names_to_types)
    #     article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
    #     cur_token_idx = 0
    #     for sent_idx, sent_tokens in enumerate(article_sent_tokens):
    #         for token in sent_tokens:
    #             cur_token_idx += 1
    #             if cur_token_idx >= 400:
    #                 sent_idx_at_400 = sent_idx
    #                 break
    #         if cur_token_idx >= 400:
    #             break
    #
    #     my_num_outside = 0
    #     for ssi in groundtruth_similar_source_indices_list:
    #         for source_idx in ssi:
    #             if source_idx >= sent_idx_at_400:
    #                 my_num_outside += 1
    #     num_outside.append(my_num_outside)
    # print "num_outside = %d" % np.mean(num_outside)


    a = 0  # no-op; likely kept as a breakpoint anchor for debugging
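
Example #3 resolves its vocabulary file by globbing for vocab_* files next to FLAGS.vocab_path and matching each suffix against the dataset name, so a derived dataset such as 'cnn_dm_coref' still picks up the 'cnn_dm' vocab. A minimal, self-contained sketch of that lookup (the paths and names here are illustrative, not the project's actual data):

    import os

    # Stand-in for glob.glob(FLAGS.vocab_path + '_*'); paths are illustrative.
    vocab_files = ['/tmp/data/vocab_cnn_dm', '/tmp/data/vocab_xsum']
    dataset_name = 'cnn_dm_coref'  # a derived dataset embedding the original name

    # Each vocab file is named vocab_<original dataset name>.
    vocab_datasets = [os.path.basename(p).split('vocab_')[1] for p in vocab_files]
    # Keep the candidates whose name appears inside the current dataset name.
    matches = [name for name in vocab_datasets if name in dataset_name]
    if len(matches) > 1:
        raise Exception('Too many choices for vocab file')
    if len(matches) < 1:
        raise Exception('No vocab file for dataset created. '
                        'Run make_vocab.py --dataset_name=<my original dataset name>')
    print('Using vocab file: vocab_' + matches[0])  # vocab_cnn_dm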