예제 #1
0
def main(unused_argv):
    """Run the decode-only seq2seq / PG-MMR pipeline configured by ``FLAGS``.

    Steps, in order:
      1. Reject leftover positional arguments and, when ``FLAGS.dataset_name``
         is set, derive ``FLAGS.data_path`` from data_root/dataset_name/split.
      2. Convert the raw dataset with ``convert_data.process_dataset`` when its
         TF-example directory is missing or empty.
      3. Move ``FLAGS.log_root`` to ``<log_root>/<exp_name>`` and keep the
         original root in ``FLAGS.actual_log_root``.
      4. Build the vocabulary and an ``HParams`` namedtuple from a fixed list
         of flag names.
      5. For PG-MMR runs (``pg_mmr`` / ``pg_mmr_sim`` / ``pg_mmr_diff``), fit
         and cache the importance model (TF-IDF vectorizer or SVR) under
         ``FLAGS.actual_log_root`` if no cached copy exists.
      6. Build a ``Batcher`` and run ``BeamSearchDecoder.decode()``.

    Args:
        unused_argv: positional argv left after flag parsing; must contain
            only the program name.

    Raises:
        Exception: on leftover positional arguments, or when ``single_pass``
            is set outside decode mode.
        ValueError: when ``FLAGS.mode`` is not 'decode', the only mode this
            entry point implements.
    """
    if len(unused_argv) != 1:  # leftover positional args mean the flags were mistyped
        raise Exception("Problem with flags: %s" % unused_argv)
    if FLAGS.dataset_name != "":
        FLAGS.data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name,
                                       FLAGS.dataset_split + '*')

    dataset_dir = os.path.join(FLAGS.data_root, FLAGS.dataset_name)
    if not os.path.exists(dataset_dir) or not os.listdir(dataset_dir):
        print('No TF example data found at %s so creating it from raw data.' % dataset_dir)
        convert_data.process_dataset(FLAGS.dataset_name)

    logging.set_verbosity(logging.INFO)  # choose what level of logging you want
    logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name. The original root is kept
    # in FLAGS.actual_log_root because the importance-model caches below live there.
    FLAGS.exp_name = FLAGS.exp_name if FLAGS.exp_name != '' else FLAGS.dataset_name
    FLAGS.actual_log_root = FLAGS.log_root
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # In decode mode one example is decoded at a time, and every beam-search step
    # holds beam_size hypotheses, so the batch is exactly the beam.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # single_pass (decode the dataset exactly once) is only meaningful in decode mode.
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception(
            "The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
                   'trunc_norm_init_std',
                   'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size',
                   'max_dec_steps',
                   'max_enc_steps', 'coverage', 'cov_loss_wt', 'pointer_gen']
    hps_dict = {}
    # .items() instead of the Python-2-only .iteritems(): works on both interpreters.
    for key, val in FLAGS.__flags.items():
        if key in hparam_list:
            hps_dict[key] = val.value
    hps = namedtuple("HParams", list(hps_dict.keys()))(**hps_dict)

    if FLAGS.pg_mmr or FLAGS.pg_mmr_sim or FLAGS.pg_mmr_diff:

        # Fit the TF-IDF vectorizer and cache it with dill if not already on disk.
        if FLAGS.importance_fn == 'tfidf':
            tfidf_dir = os.path.join(FLAGS.actual_log_root, 'tfidf_vectorizer')
            tfidf_model_path = os.path.join(tfidf_dir, FLAGS.dataset_name + '.dill')
            if not os.path.exists(tfidf_model_path):
                print('No TFIDF vectorizer model file found at %s, so fitting the model now.' % tfidf_model_path)
                tfidf_vectorizer = fit_tfidf_vectorizer(hps, vocab)
                # The cache directory does not exist on a fresh log root; open() would fail.
                if not os.path.isdir(tfidf_dir):
                    os.makedirs(tfidf_dir)
                with open(tfidf_model_path, 'wb') as f:
                    dill.dump(tfidf_vectorizer, f)

        # Train the SVR importance model on the 'val' split of the current dataset
        # if no trained model is cached yet.
        if FLAGS.importance_fn == 'svr':
            save_path = os.path.join(FLAGS.data_root, 'svr_training_data')
            importance_model_path = os.path.join(FLAGS.actual_log_root, 'svr.pickle')
            dataset_split = 'val'
            if not os.path.exists(importance_model_path):
                # Extract the per-sentence feature instances first, unless already done.
                if not os.path.exists(save_path) or not os.listdir(save_path):
                    print('No importance_feature instances found at %s so creating it from raw data.' % save_path)
                    # The model is configured with max_dec_steps=1 because only one decoder
                    # step is ever run at a time; the batcher keeps the full max_dec_steps
                    # because batches must contain the full summaries.
                    decode_model_hps = hps._replace(
                        max_dec_steps=1, batch_size=100, mode='calc_features')
                    svr_data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name,
                                                 dataset_split + '*')
                    batcher = Batcher(svr_data_path, vocab, decode_model_hps,
                                      single_pass=FLAGS.single_pass,
                                      cnn_500_dm_500=False)
                    calc_features(svr_data_path, decode_model_hps, vocab, batcher, save_path)

                print('No importance_feature SVR model found at %s so training it now.' % importance_model_path)
                features_list = importance_features.get_features_list(True)
                sent_reps = importance_features.load_data(
                    os.path.join(save_path, dataset_split + '*'), -1)
                # print() call instead of the Python-2-only print statement.
                print('Loaded %d sentences representations' % len(sent_reps))
                x_y = importance_features.features_to_array(sent_reps, features_list)
                # Every column but the last is a feature; the last column is train_y.
                train_x, train_y = x_y[:, :-1], x_y[:, -1]
                svr_model = importance_features.run_training(train_x, train_y)
                if FLAGS.actual_log_root and not os.path.isdir(FLAGS.actual_log_root):
                    os.makedirs(FLAGS.actual_log_root)
                with open(importance_model_path, 'wb') as f:
                    cPickle.dump(svr_model, f)

    # Create a batcher object that will create minibatches of data
    batcher = Batcher(FLAGS.data_path, vocab, hps,
                      single_pass=FLAGS.single_pass)

    tf.set_random_seed(111)  # a seed value for randomness

    if hps.mode == 'decode':
        # One decoder step per run (beam search); the batcher still uses the full
        # max_dec_steps because the batches need to contain the full summaries.
        decode_model_hps = hps._replace(max_dec_steps=1)
        model = SummarizationModel(decode_model_hps, vocab)
        decoder = BeamSearchDecoder(model, batcher, vocab)
        decoder.decode()  # runs indefinitely unless single_pass=True (then exactly one pass)
    else:
        # train/eval are not implemented in this entry point.
        raise ValueError("The 'mode' flag must be 'decode'; train/eval are not supported by this entry point")
예제 #2
0
def main(unused_argv):
    """Train, evaluate, or decode the summarization model configured by ``FLAGS``.

    Steps, in order:
      1. Reject leftover positional arguments, resolve ``FLAGS.data_root`` /
         ``FLAGS.data_path``, and convert raw data when the dataset's
         TF-example directory is missing or empty.
      2. Derive ``FLAGS.exp_name`` from the experiment flags and, in decode
         mode, ``FLAGS.pretrained_path`` and ``FLAGS.ssi_data_path``.
      3. Move ``FLAGS.log_root`` to ``<log_root>/<exp_name>`` (plus an
         ``_imp...`` suffix when ``word_imp_reg`` is on), keeping the original
         root in ``FLAGS.actual_log_root``.
      4. Load the vocabulary (duc_2004 reuses the cnn_dm vocabulary).
      5. Build ``HParams`` from every defined flag.
      6. For PG-MMR, fit/train and cache the importance model if needed.
      7. Dispatch on ``hps.mode`` to ``setup_training``, ``run_eval`` or
         ``BeamSearchDecoder.decode``.

    Args:
        unused_argv: positional argv left after flag parsing; must contain
            only the program name.

    Raises:
        Exception: on leftover positional arguments, when ``word_imp_reg`` is
            combined with ``coverage``, when zero or several vocab files match
            the dataset name, or when ``single_pass`` is set outside decode mode.
        ValueError: when ``FLAGS.mode`` is not train/eval/decode.
    """
    if len(unused_argv) != 1:  # leftover positional args mean the flags were mistyped
        raise Exception("Problem with flags: %s" % unused_argv)
    if FLAGS.pg_mmr:
        # PG-MMR reads its TF examples from a dedicated directory.
        FLAGS.data_root = os.path.join(os.path.expanduser('~'), 'data', 'tf_data', 'with_coref_and_ssi')
    if FLAGS.dataset_name != "":
        FLAGS.data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name, FLAGS.dataset_split + '*')
    if FLAGS.dataset_name in kaiqiang_dataset_names:
        FLAGS.skip_with_less_than_3 = False
    dataset_dir = os.path.join(FLAGS.data_root, FLAGS.dataset_name)
    if not os.path.exists(dataset_dir) or not os.listdir(dataset_dir):
        print('No TF example data found at %s so creating it from raw data.' % dataset_dir)
        convert_data.process_dataset(FLAGS.dataset_name)

    # Decode runs always start from a pretrained model, make a single pass, and
    # tag the experiment name with the sentence extractor (BERT or LambdaMART).
    if FLAGS.mode == 'decode':
        extractor_suffix = '_bert' if FLAGS.use_bert else '_lambdamart'
        FLAGS.use_pretrained = True
        FLAGS.single_pass = True
    else:
        extractor_suffix = ''
    # DUC 2004 is decoded with a model pretrained on CNN/DM.
    pretrained_dataset = 'cnn_dm' if FLAGS.dataset_name == 'duc_2004' else FLAGS.dataset_name
    if FLAGS.pg_mmr:
        FLAGS.exp_name += '_pgmmr'
    if FLAGS.singles_and_pairs == 'both':
        FLAGS.exp_name = FLAGS.exp_name + extractor_suffix + '_both'
        if FLAGS.mode == 'decode':
            FLAGS.pretrained_path = os.path.join(FLAGS.log_root, pretrained_dataset + '_pgmmr_both')
    elif FLAGS.singles_and_pairs == 'singles':
        FLAGS.exp_name = FLAGS.exp_name + extractor_suffix + '_singles'
        if FLAGS.mode == 'decode':
            FLAGS.pretrained_path = os.path.join(FLAGS.log_root, pretrained_dataset + '_pgmmr_singles')

    if FLAGS.notrain:
        FLAGS.exp_name += '_notrain'
        FLAGS.pretrained_path = original_pretrained_path[FLAGS.dataset_name]
    if FLAGS.finetune:
        FLAGS.exp_name += '_finetune'
        if FLAGS.mode == 'decode':
            FLAGS.pretrained_path += '_finetune'
    if FLAGS.sep:
        FLAGS.exp_name += '_sep'
    if FLAGS.tag_tokens:
        FLAGS.exp_name += '_tag'

    extractor = 'bert' if FLAGS.use_bert else 'lambdamart'
    bert_suffix = ''
    if FLAGS.upper_bound:
        # Upper-bound evaluation: the SSIs come straight from the ground truth.
        FLAGS.exp_name = FLAGS.exp_name + '_upperbound'
    elif FLAGS.mode == 'decode':
        FLAGS.ssi_data_path = os.path.join(
            log_dir, '%s_%s_%s%s' % (FLAGS.dataset_name, extractor, FLAGS.singles_and_pairs, bert_suffix))

    logging.set_verbosity(logging.INFO)  # choose what level of logging you want
    logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name; the original root is kept in
    # FLAGS.actual_log_root because the importance-model caches below live there.
    FLAGS.exp_name = FLAGS.exp_name if FLAGS.exp_name != '' else FLAGS.dataset_name
    FLAGS.actual_log_root = FLAGS.log_root
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)

    if FLAGS.convert_to_importance_model:
        convert_to_importance_model()
    if FLAGS.word_imp_reg:
        if FLAGS.coverage:
            raise Exception('Importance loss does not work at the same time with coverage loss yet. Need to modify the total_loss in model.py.')
        FLAGS.log_root += '_imp' + str(FLAGS.imp_loss_wt)
        if FLAGS.imp_loss_oneminus:
            FLAGS.log_root += '_oneminus'

    print(util.bcolors.OKGREEN + "Experiment path: " + FLAGS.log_root + util.bcolors.ENDC)

    if FLAGS.dataset_name == 'duc_2004':
        vocab = Vocab(FLAGS.vocab_path + '_' + 'cnn_dm', FLAGS.vocab_size, add_sep=FLAGS.sep)  # create a vocabulary
    else:
        # Vocab files are named <vocab_path>_<dataset>; use the single one whose
        # dataset name is a substring of FLAGS.dataset_name.
        vocab_datasets = [os.path.basename(file_path).split('vocab_')[1]
                          for file_path in glob.glob(FLAGS.vocab_path + '_*')]
        matching_names = [name for name in vocab_datasets if name in FLAGS.dataset_name]
        if len(matching_names) > 1:
            raise Exception('Too many choices for vocab file')
        if not matching_names:
            raise Exception('No vocab file for dataset created. Run make_vocab.py --dataset_name=<my original dataset name>')
        original_dataset_name = matching_names[0]
        FLAGS.original_dataset_name = original_dataset_name
        vocab = Vocab(FLAGS.vocab_path + '_' + original_dataset_name, FLAGS.vocab_size, add_sep=FLAGS.sep)  # create a vocabulary

    # In decode mode one example is decoded at a time, and every beam-search step
    # holds beam_size hypotheses, so the batch is exactly the beam.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # single_pass (decode the dataset exactly once) is only meaningful in decode mode.
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception("The single_pass flag should only be True in decode mode")

    # Every defined flag becomes a hyperparameter. '?' is skipped because namedtuple
    # field names must be valid identifiers.
    hps_dict = {key: value for key, value in FLAGS.flag_values_dict().items() if key != '?'}
    hps = namedtuple("HParams", list(hps_dict.keys()))(**hps_dict)

    if FLAGS.pg_mmr:

        # Fit the TF-IDF vectorizer and cache it with dill if not already on disk.
        if FLAGS.importance_fn == 'tfidf':
            tfidf_dir = os.path.join(FLAGS.actual_log_root, 'tfidf_vectorizer')
            # NOTE(review): FLAGS.original_dataset_name is only assigned above for
            # non-duc_2004 datasets; for duc_2004 this uses the flag's default — confirm intended.
            tfidf_model_path = os.path.join(tfidf_dir, FLAGS.original_dataset_name + '.dill')
            if not os.path.exists(tfidf_model_path):
                print('No TFIDF vectorizer model file found at %s, so fitting the model now.' % tfidf_model_path)
                tfidf_vectorizer = fit_tfidf_vectorizer(hps, vocab)
                # The cache directory does not exist on a fresh log root; open() would fail.
                if not os.path.isdir(tfidf_dir):
                    os.makedirs(tfidf_dir)
                with open(tfidf_model_path, 'wb') as f:
                    dill.dump(tfidf_vectorizer, f)

        # Train the SVR importance model on the cnn_500_dm_500 'val' split if not already trained.
        if FLAGS.importance_fn == 'svr':
            save_path = os.path.join(FLAGS.data_root, 'svr_training_data')
            importance_model_path = os.path.join(FLAGS.actual_log_root, 'svr.pickle')
            dataset_split = 'val'
            if not os.path.exists(importance_model_path):
                # Extract the per-sentence feature instances first, unless already done.
                if not os.path.exists(save_path) or not os.listdir(save_path):
                    print('No importance_feature instances found at %s so creating it from raw data.' % save_path)
                    # max_dec_steps=1 because only one decoder step is run at a time; the
                    # batcher keeps full max_dec_steps since batches carry full summaries.
                    decode_model_hps = hps._replace(
                        max_dec_steps=1, batch_size=100, mode='calc_features')
                    svr_data_path = os.path.join(FLAGS.data_root, 'cnn_500_dm_500', dataset_split + '*')
                    batcher = Batcher(svr_data_path, vocab, decode_model_hps,
                                      single_pass=FLAGS.single_pass, cnn_500_dm_500=True)
                    calc_features(svr_data_path, decode_model_hps, vocab, batcher, save_path)

                print('No importance_feature SVR model found at %s so training it now.' % importance_model_path)
                features_list = importance_features.get_features_list(True)
                sent_reps = importance_features.load_data(os.path.join(save_path, dataset_split + '*'), -1)
                print('Loaded %d sentences representations' % len(sent_reps))
                x_y = importance_features.features_to_array(sent_reps, features_list)
                # Every column but the last is a feature; the last column is train_y.
                train_x, train_y = x_y[:, :-1], x_y[:, -1]
                svr_model = importance_features.run_training(train_x, train_y)
                if FLAGS.actual_log_root and not os.path.isdir(FLAGS.actual_log_root):
                    os.makedirs(FLAGS.actual_log_root)
                with open(importance_model_path, 'wb') as f:
                    pickle.dump(svr_model, f)

    # Create a batcher object that will create minibatches of data
    batcher = Batcher(FLAGS.data_path, vocab, hps, single_pass=FLAGS.single_pass)

    tf.set_random_seed(113)  # a seed value for randomness

    # Dispatch on the requested mode.
    if hps.mode == 'train':
        print("creating model...")
        model = SummarizationModel(hps, vocab)
        setup_training(model, batcher)
    elif hps.mode == 'eval':
        model = SummarizationModel(hps, vocab)
        run_eval(model, batcher, vocab)
    elif hps.mode == 'decode':
        # One decoder step per run (beam search); the batcher still uses the full
        # max_dec_steps because the batches need to contain the full summaries.
        decode_model_hps = hps._replace(max_dec_steps=1)
        model = SummarizationModel(decode_model_hps, vocab)
        decoder = BeamSearchDecoder(model, batcher, vocab)
        decoder.decode()  # runs indefinitely unless single_pass=True (then exactly one pass)
    else:
        raise ValueError("The 'mode' flag must be one of train/eval/decode")
예제 #3
0
def main(unused_argv):
    """Decode every example of a dataset split using precomputed source-sentence indices.

    Builds the experiment name, vocabulary and ``HParams`` from ``FLAGS``,
    loads ``ssi.pkl`` from ``<log_dir>/<FLAGS.ssi_exp_name>``, and feeds a
    TF-example generator over ``<data_dir>/<dataset_articles>/<dataset_split>*``
    to ``BeamSearchDecoder.decode_iteratively``.

    Args:
        unused_argv: positional argv left after flag parsing; must contain
            only the program name.

    Raises:
        Exception: on leftover positional arguments, or when ``single_pass``
            is set outside decode mode.
    """
    if len(unused_argv) != 1:  # leftover positional args mean the flags were mistyped
        raise Exception("Problem with flags: %s" % unused_argv)

    # 'both' reads the base article set; any other value reads its '_singles' variant.
    if FLAGS.singles_and_pairs == 'both':
        FLAGS.exp_name = FLAGS.exp_name + '_both'
        dataset_articles = _dataset_articles
    else:
        FLAGS.exp_name = FLAGS.exp_name + '_singles'
        dataset_articles = _dataset_articles + '_singles'
    my_log_dir = os.path.join(log_dir, FLAGS.ssi_exp_name)

    print('Running statistics on %s' % FLAGS.exp_name)

    if FLAGS.dataset_name != "":
        FLAGS.data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name,
                                       FLAGS.dataset_split + '*')
    dataset_dir = os.path.join(FLAGS.data_root, FLAGS.dataset_name)
    if not os.path.exists(dataset_dir) or not os.listdir(dataset_dir):
        print('No TF example data found at %s so creating it from raw data.' % dataset_dir)
        convert_data.process_dataset(FLAGS.dataset_name)

    logging.set_verbosity(logging.INFO)  # choose what level of logging you want
    logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name, remembering the original root.
    FLAGS.exp_name = FLAGS.exp_name if FLAGS.exp_name != '' else FLAGS.dataset_name
    FLAGS.actual_log_root = FLAGS.log_root
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # In decode mode one example is decoded at a time, and every beam-search step
    # holds beam_size hypotheses, so the batch is exactly the beam.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # single_pass (decode the dataset exactly once) is only meaningful in decode mode.
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception(
            "The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = [
        'mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag',
        'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim',
        'batch_size', 'max_dec_steps', 'max_enc_steps', 'coverage',
        'cov_loss_wt', 'pointer_gen', 'lambdamart_input'
    ]
    hps_dict = {key: val.value for key, val in FLAGS.__flags.items() if key in hparam_list}
    hps = namedtuple("HParams", list(hps_dict.keys()))(**hps_dict)

    tf.set_random_seed(113)  # a seed value for randomness

    # The model is configured with max_dec_steps=1 because only one decoder step
    # is ever run at a time (beam search).
    decode_model_hps = hps._replace(max_dec_steps=1)

    np.random.seed(random_seed)
    source_dir = os.path.join(data_dir, dataset_articles)
    source_pattern = source_dir + '/' + dataset_split + '*'
    num_source_files = len(glob.glob(source_pattern))

    # Pickles must be read in binary mode; text mode breaks pickle.load on Python 3.
    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:
        ssi_list = pickle.load(f)

    # Example-count estimate passed to decode_iteratively. Each membership test is
    # spelled out: the old `'cnn' or 'newsroom' in ...` was always true.
    # NOTE(review): the x1000 presumably reflects ~1000 examples per cnn/newsroom
    # file — confirm.
    if 'cnn' in dataset_articles or 'newsroom' in dataset_articles:
        total = num_source_files * 1000
    else:
        total = num_source_files
    example_generator = data.example_generator(source_pattern,
                                               True,
                                               False,
                                               should_check_valid=False)
    model = SummarizationModel(decode_model_hps, vocab)
    decoder = BeamSearchDecoder(model, None, vocab)
    decoder.decode_iteratively(example_generator, total, names_to_types,
                               ssi_list, hps)
예제 #4
0
def main(unused_argv):
    """Decode a dataset split example-by-example with a pretrained model.

    Derives ``FLAGS.exp_name`` and ``FLAGS.pretrained_path`` from the experiment
    flags. Unless ``upper_bound`` is set, it loads ``ssi.pkl`` from
    ``<log_dir>/<dataset>_<extractor>_<singles_and_pairs>_tag<wt>``. It then
    builds the vocabulary and ``HParams`` and feeds a TF-example generator over
    ``<data_root>/<dataset_articles>/<dataset_split>*`` to
    ``BeamSearchDecoder.decode_iteratively``.

    Args:
        unused_argv: positional argv left after flag parsing; must contain
            only the program name.

    Raises:
        Exception: on leftover positional arguments, when zero or several vocab
            files match the dataset name, or when ``single_pass`` is set outside
            decode mode.
    """
    if len(unused_argv) != 1:  # leftover positional args mean the flags were mistyped
        raise Exception("Problem with flags: %s" % unused_argv)

    extractor = 'bert' if FLAGS.use_bert else 'lambdamart'
    if FLAGS.cnn_dm_pg:
        pretrained_dataset = 'cnn_dm'
    elif FLAGS.websplit:
        pretrained_dataset = 'websplit'
    else:
        pretrained_dataset = FLAGS.dataset_name
    if FLAGS.dataset_name == 'duc_2004':
        pretrained_dataset = 'cnn_dm'
    # 'both' reads the base article set; any other value reads its '_singles' variant.
    if FLAGS.singles_and_pairs == 'both':
        FLAGS.exp_name = FLAGS.dataset_name + '_' + FLAGS.exp_name + extractor + '_both'
        FLAGS.pretrained_path = os.path.join(FLAGS.log_root, pretrained_dataset + '_both')
        dataset_articles = FLAGS.dataset_name
    else:
        FLAGS.exp_name = FLAGS.dataset_name + '_' + FLAGS.exp_name + extractor + '_singles'
        FLAGS.pretrained_path = os.path.join(FLAGS.log_root, pretrained_dataset + '_singles')
        dataset_articles = FLAGS.dataset_name + '_singles'
    if FLAGS.word_imp_reg:
        FLAGS.pretrained_path += '_imp' + str(FLAGS.imp_loss_wt)
        FLAGS.exp_name += '_imp' + str(FLAGS.imp_loss_wt)
        if FLAGS.imp_loss_oneminus:
            FLAGS.pretrained_path += '_oneminus'
            FLAGS.exp_name += '_oneminus'
    if FLAGS.sep:
        FLAGS.pretrained_path += '_sep'
        FLAGS.exp_name += '_sep'
    if FLAGS.tag_tokens:
        FLAGS.pretrained_path += '_tag'
        FLAGS.exp_name += '_tag' + str(FLAGS.tag_loss_wt)

    # The SSI directory name always carries the tag-loss weight ('0.0' when tagging is off).
    if FLAGS.tag_tokens:
        bert_suffix = '_tag' + str(FLAGS.tag_loss_wt)
    else:
        bert_suffix = '_tag' + '0.0'
    if FLAGS.upper_bound:
        FLAGS.exp_name = FLAGS.exp_name + '_upperbound'
        ssi_list = None  # upper-bound evaluation: ssi_list comes straight from the groundtruth
    else:
        my_log_dir = os.path.join(log_dir, '%s_%s_%s%s' % (FLAGS.dataset_name, extractor, FLAGS.singles_and_pairs, bert_suffix))
        print(util.bcolors.OKGREEN + "BERT path: " + my_log_dir + util.bcolors.ENDC)
        with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:
            ssi_list = pickle.load(f)
        FLAGS.ssi_data_path = my_log_dir
    if FLAGS.cnn_dm_pg:
        FLAGS.exp_name = FLAGS.exp_name + '_cnntrained'
    if FLAGS.websplit:
        FLAGS.exp_name = FLAGS.exp_name + '_websplittrained'
    if FLAGS.first_intact:
        FLAGS.exp_name = FLAGS.exp_name + '_firstintact'

    print('Running statistics on %s' % FLAGS.exp_name)

    if FLAGS.dataset_name != "":
        FLAGS.data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name, FLAGS.dataset_split + '*')
    dataset_dir = os.path.join(FLAGS.data_root, FLAGS.dataset_name)
    if not os.path.exists(dataset_dir) or not os.listdir(dataset_dir):
        print('No TF example data found at %s so creating it from raw data.' % dataset_dir)
        convert_data.process_dataset(FLAGS.dataset_name)

    logging.set_verbosity(logging.INFO)  # choose what level of logging you want
    logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name, remembering the original root.
    FLAGS.exp_name = FLAGS.exp_name if FLAGS.exp_name != '' else FLAGS.dataset_name
    FLAGS.actual_log_root = FLAGS.log_root
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)

    print(util.bcolors.OKGREEN + "Experiment path: " + FLAGS.log_root + util.bcolors.ENDC)

    if FLAGS.dataset_name == 'duc_2004':
        vocab = Vocab(FLAGS.vocab_path + '_' + 'cnn_dm', FLAGS.vocab_size, add_sep=FLAGS.sep)  # create a vocabulary
    else:
        # Vocab files are named <vocab_path>_<dataset>; use the single one whose
        # dataset name is a substring of FLAGS.dataset_name.
        vocab_datasets = [os.path.basename(file_path).split('vocab_')[1]
                          for file_path in glob.glob(FLAGS.vocab_path + '_*')]
        matching_names = [name for name in vocab_datasets if name in FLAGS.dataset_name]
        if len(matching_names) > 1:
            raise Exception('Too many choices for vocab file')
        if not matching_names:
            raise Exception('No vocab file for dataset created. Run make_vocab.py --dataset_name=<my original dataset name>')
        original_dataset_name = matching_names[0]
        FLAGS.original_dataset_name = original_dataset_name
        vocab = Vocab(FLAGS.vocab_path + '_' + original_dataset_name, FLAGS.vocab_size, add_sep=FLAGS.sep)  # create a vocabulary

    # In decode mode one example is decoded at a time, and every beam-search step
    # holds beam_size hypotheses, so the batch is exactly the beam.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # single_pass (decode the dataset exactly once) is only meaningful in decode mode.
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception("The single_pass flag should only be True in decode mode")

    # Every defined flag becomes a hyperparameter. '?' is skipped because namedtuple
    # field names must be valid identifiers.
    hps_dict = {key: value for key, value in FLAGS.flag_values_dict().items() if key != '?'}
    hps = namedtuple("HParams", list(hps_dict.keys()))(**hps_dict)

    tf.set_random_seed(113)  # a seed value for randomness

    # The model is configured with max_dec_steps=1 because only one decoder step
    # is ever run at a time (beam search).
    decode_model_hps = hps._replace(max_dec_steps=1)

    np.random.seed(random_seed)
    source_dir = os.path.join(FLAGS.data_root, dataset_articles)
    source_pattern = source_dir + '/' + dataset_split + '*'
    num_source_files = len(glob.glob(source_pattern))

    # Example-count estimate passed to decode_iteratively.
    # NOTE(review): the x1000 presumably reflects ~1000 examples per cnn/xsum file — confirm.
    if 'cnn' in dataset_articles or 'xsum' in dataset_articles:
        total = num_source_files * 1000
    else:
        total = num_source_files
    example_generator = data.example_generator(source_pattern, True, False,
                                               should_check_valid=False)
    model = SummarizationModel(decode_model_hps, vocab)
    decoder = BeamSearchDecoder(model, None, vocab)
    decoder.decode_iteratively(example_generator, total, names_to_types, ssi_list, hps)