def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.dataset_name != "":
        FLAGS.data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name, FLAGS.dataset_split + '*')
    if not os.path.exists(os.path.join(FLAGS.data_root, FLAGS.dataset_name)) or len(
            os.listdir(os.path.join(FLAGS.data_root, FLAGS.dataset_name))) == 0:
        print('No TF example data found at %s so creating it from raw data.' % os.path.join(
            FLAGS.data_root, FLAGS.dataset_name))
        convert_data.process_dataset(FLAGS.dataset_name)

    logging.set_verbosity(logging.INFO)  # choose what level of logging you want
    logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.exp_name = FLAGS.exp_name if FLAGS.exp_name != '' else FLAGS.dataset_name
    FLAGS.actual_log_root = FLAGS.log_root
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # If in decode mode, set batch_size = beam_size
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception("The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
                   'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps',
                   'max_enc_steps', 'coverage', 'cov_loss_wt', 'pointer_gen']
    hps_dict = {}
    for key, val in FLAGS.__flags.iteritems():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val.value  # add it to the dict
    hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    if FLAGS.pg_mmr or FLAGS.pg_mmr_sim or FLAGS.pg_mmr_diff:
        # Fit the TFIDF vectorizer if not already fitted
        if FLAGS.importance_fn == 'tfidf':
            tfidf_model_path = os.path.join(FLAGS.actual_log_root, 'tfidf_vectorizer', FLAGS.dataset_name + '.dill')
            if not os.path.exists(tfidf_model_path):
                print('No TFIDF vectorizer model file found at %s, so fitting the model now.' % tfidf_model_path)
                tfidf_vectorizer = fit_tfidf_vectorizer(hps, vocab)
                with open(tfidf_model_path, 'wb') as f:
                    dill.dump(tfidf_vectorizer, f)

        # Train the SVR model on the CNN validation set if not already trained
        if FLAGS.importance_fn == 'svr':
            save_path = os.path.join(FLAGS.data_root, 'svr_training_data')
            importance_model_path = os.path.join(FLAGS.actual_log_root, 'svr.pickle')
            dataset_split = 'val'
            if not os.path.exists(importance_model_path):
                if not os.path.exists(save_path) or len(os.listdir(save_path)) == 0:
                    print('No importance_feature instances found at %s so creating it from raw data.' % save_path)
                    # The model is configured with max_dec_steps=1 because we only ever run one step of the
                    # decoder at a time (to do beam search). Note that the batcher is initialized with
                    # max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries.
                    decode_model_hps = hps._replace(max_dec_steps=1, batch_size=100, mode='calc_features')
                    cnn_dm_train_data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name, dataset_split + '*')
                    batcher = Batcher(cnn_dm_train_data_path, vocab, decode_model_hps,
                                      single_pass=FLAGS.single_pass, cnn_500_dm_500=False)
                    calc_features(cnn_dm_train_data_path, decode_model_hps, vocab, batcher, save_path)

                print('No importance_feature SVR model found at %s so training it now.' % importance_model_path)
                features_list = importance_features.get_features_list(True)
                sent_reps = importance_features.load_data(os.path.join(save_path, dataset_split + '*'), -1)
                print('Loaded %d sentence representations' % len(sent_reps))
                x_y = importance_features.features_to_array(sent_reps, features_list)
                train_x, train_y = x_y[:, :-1], x_y[:, -1]
                svr_model = importance_features.run_training(train_x, train_y)
                with open(importance_model_path, 'wb') as f:
                    cPickle.dump(svr_model, f)

    # Create a batcher object that will create minibatches of data
    batcher = Batcher(FLAGS.data_path, vocab, hps, single_pass=FLAGS.single_pass)

    tf.set_random_seed(111)  # a seed value for randomness

    # Start decoding on multi-document inputs
    if hps.mode == 'decode':
        # The model is configured with max_dec_steps=1 because we only ever run one step of the
        # decoder at a time (to do beam search). Note that the batcher is initialized with
        # max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries.
        decode_model_hps = hps._replace(max_dec_steps=1)
        model = SummarizationModel(decode_model_hps, vocab)
        decoder = BeamSearchDecoder(model, batcher, vocab)
        decoder.decode()  # decode indefinitely (unless single_pass=True, in which case decode the dataset exactly once)
    else:
        raise ValueError("The 'mode' flag must be one of train/eval/decode")
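
# A self-contained toy (not part of the original scripts) illustrating the HParams pattern
# used in main() above: flag values are frozen into an immutable namedtuple, and _replace()
# yields a modified copy, e.g. max_dec_steps=1 so the decoder advances one step at a time
# during beam search. The values below are illustrative, not the project's defaults.
from collections import namedtuple


def _toy_hparams_example():
    hps_dict = {'mode': 'decode', 'batch_size': 4, 'beam_size': 4, 'max_dec_steps': 100}
    hps = namedtuple("HParams", list(hps_dict.keys()))(**hps_dict)
    # In decode mode the batch is filled with beam_size hypotheses, and the model itself
    # only ever advances the decoder by a single step per call.
    decode_model_hps = hps._replace(batch_size=hps.beam_size, max_dec_steps=1)
    return hps, decode_model_hps
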
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    # if '_sent' in FLAGS.dataset_name:
    #     FLAGS.data_root = os.path.expanduser('~') + '/data/tf_data/with_coref_and_tag_tokens'
    if FLAGS.pg_mmr:
        FLAGS.data_root = os.path.expanduser('~') + "/data/tf_data/with_coref_and_ssi"
    if FLAGS.dataset_name != "":
        FLAGS.data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name, FLAGS.dataset_split + '*')
    if FLAGS.dataset_name in kaiqiang_dataset_names:
        FLAGS.skip_with_less_than_3 = False
    if not os.path.exists(os.path.join(FLAGS.data_root, FLAGS.dataset_name)) or len(
            os.listdir(os.path.join(FLAGS.data_root, FLAGS.dataset_name))) == 0:
        print('No TF example data found at %s so creating it from raw data.' % os.path.join(FLAGS.data_root, FLAGS.dataset_name))
        convert_data.process_dataset(FLAGS.dataset_name)

    if FLAGS.mode == 'decode':
        extractor = '_bert' if FLAGS.use_bert else '_lambdamart'
        FLAGS.use_pretrained = True
        FLAGS.single_pass = True
    else:
        extractor = ''
    pretrained_dataset = FLAGS.dataset_name
    if FLAGS.dataset_name == 'duc_2004':
        pretrained_dataset = 'cnn_dm'
    if FLAGS.pg_mmr:
        FLAGS.exp_name += '_pgmmr'
    if FLAGS.singles_and_pairs == 'both':
        FLAGS.exp_name = FLAGS.exp_name + extractor + '_both'
        if FLAGS.mode == 'decode':
            FLAGS.pretrained_path = os.path.join(FLAGS.log_root, pretrained_dataset + '_pgmmr_both')
        dataset_articles = FLAGS.dataset_name
    elif FLAGS.singles_and_pairs == 'singles':
        FLAGS.exp_name = FLAGS.exp_name + extractor + '_singles'
        if FLAGS.mode == 'decode':
            FLAGS.pretrained_path = os.path.join(FLAGS.log_root, pretrained_dataset + '_pgmmr_singles')
        dataset_articles = FLAGS.dataset_name + '_singles'
    if FLAGS.notrain:
        FLAGS.exp_name += '_notrain'
        FLAGS.pretrained_path = original_pretrained_path[FLAGS.dataset_name]
    if FLAGS.finetune:
        FLAGS.exp_name += '_finetune'
        if FLAGS.mode == 'decode':
            FLAGS.pretrained_path += '_finetune'
    if FLAGS.sep:
        FLAGS.exp_name += '_sep'
    if FLAGS.tag_tokens:
        FLAGS.exp_name += '_tag'

    extractor = 'bert' if FLAGS.use_bert else 'lambdamart'
    bert_suffix = ''
    # if FLAGS.use_bert:
    #     if FLAGS.sentemb:
    #         bert_suffix += '_sentemb'
    #     if FLAGS.artemb:
    #         bert_suffix += '_artemb'
    #     if FLAGS.plushidden:
    #         bert_suffix += '_plushidden'
    # if FLAGS.mode == 'decode':
    #     if FLAGS.sentemb:
    #         FLAGS.exp_name += '_sentemb'
    #     if FLAGS.artemb:
    #         FLAGS.exp_name += '_artemb'
    #     if FLAGS.plushidden:
    #         FLAGS.exp_name += '_plushidden'
    if FLAGS.upper_bound:
        FLAGS.exp_name = FLAGS.exp_name + '_upperbound'
        ssi_list = None  # this is if we are doing the upper bound evaluation (ssi_list comes straight from the groundtruth)
    else:
        if FLAGS.mode == 'decode':
            my_log_dir = os.path.join(log_dir, '%s_%s_%s%s' % (FLAGS.dataset_name, extractor, FLAGS.singles_and_pairs, bert_suffix))
            FLAGS.ssi_data_path = my_log_dir

    logging.set_verbosity(logging.INFO)  # choose what level of logging you want
    logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.exp_name = FLAGS.exp_name if FLAGS.exp_name != '' else FLAGS.dataset_name
    FLAGS.actual_log_root = FLAGS.log_root
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    if FLAGS.convert_to_importance_model:
        convert_to_importance_model()
        # FLAGS.convert_to_coverage_model = True
    if FLAGS.word_imp_reg:
        if FLAGS.coverage:
            raise Exception('Importance loss does not work at the same time with coverage loss yet. Need to modify the total_loss in model.py.')
        FLAGS.log_root += '_imp' + str(FLAGS.imp_loss_wt)
        if FLAGS.imp_loss_oneminus:
            FLAGS.log_root += '_oneminus'
    print(util.bcolors.OKGREEN + "Experiment path: " + FLAGS.log_root + util.bcolors.ENDC)

    if FLAGS.dataset_name == 'duc_2004':
        vocab = Vocab(FLAGS.vocab_path + '_' + 'cnn_dm', FLAGS.vocab_size, add_sep=FLAGS.sep)  # create a vocabulary
    else:
        vocab_datasets = [os.path.basename(file_path).split('vocab_')[1] for file_path in glob.glob(FLAGS.vocab_path + '_*')]
        original_dataset_name = [file_name for file_name in vocab_datasets if file_name in FLAGS.dataset_name]
        if len(original_dataset_name) > 1:
            raise Exception('Too many choices for vocab file')
        if len(original_dataset_name) < 1:
            raise Exception('No vocab file for dataset created. Run make_vocab.py --dataset_name=<my original dataset name>')
        original_dataset_name = original_dataset_name[0]
        FLAGS.original_dataset_name = original_dataset_name
        vocab = Vocab(FLAGS.vocab_path + '_' + original_dataset_name, FLAGS.vocab_size, add_sep=FLAGS.sep)  # create a vocabulary

    # If in decode mode, set batch_size = beam_size
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception("The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    # hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
    #                'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps',
    #                'max_enc_steps', 'coverage', 'cov_loss_wt', 'pointer_gen', 'lambdamart_input', 'pg_mmr',
    #                'singles_and_pairs', 'skip_with_less_than_3', 'ssi_data_path',
    #                'dataset_name', 'word_imp_reg', 'imp_loss_wt', 'tag_tokens']
    hparam_list = [item for item in list(FLAGS.flag_values_dict().keys()) if item != '?']
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val.value  # add it to the dict
    hps = namedtuple("HParams", list(hps_dict.keys()))(**hps_dict)

    if FLAGS.pg_mmr:
        # Fit the TFIDF vectorizer if not already fitted
        if FLAGS.importance_fn == 'tfidf':
            tfidf_model_path = os.path.join(FLAGS.actual_log_root, 'tfidf_vectorizer', FLAGS.original_dataset_name + '.dill')
            if not os.path.exists(tfidf_model_path):
                print('No TFIDF vectorizer model file found at %s, so fitting the model now.' % tfidf_model_path)
                tfidf_vectorizer = fit_tfidf_vectorizer(hps, vocab)
                with open(tfidf_model_path, 'wb') as f:
                    dill.dump(tfidf_vectorizer, f)

        # Train the SVR model on the CNN validation set if not already trained
        if FLAGS.importance_fn == 'svr':
            save_path = os.path.join(FLAGS.data_root, 'svr_training_data')
            importance_model_path = os.path.join(FLAGS.actual_log_root, 'svr.pickle')
            dataset_split = 'val'
            if not os.path.exists(importance_model_path):
                if not os.path.exists(save_path) or len(os.listdir(save_path)) == 0:
                    print('No importance_feature instances found at %s so creating it from raw data.' % save_path)
                    # The model is configured with max_dec_steps=1 because we only ever run one step of the
                    # decoder at a time (to do beam search). Note that the batcher is initialized with
                    # max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries.
                    decode_model_hps = hps._replace(max_dec_steps=1, batch_size=100, mode='calc_features')
                    cnn_dm_train_data_path = os.path.join(FLAGS.data_root, 'cnn_500_dm_500', dataset_split + '*')
                    batcher = Batcher(cnn_dm_train_data_path, vocab, decode_model_hps,
                                      single_pass=FLAGS.single_pass, cnn_500_dm_500=True)
                    calc_features(cnn_dm_train_data_path, decode_model_hps, vocab, batcher, save_path)

                print('No importance_feature SVR model found at %s so training it now.' % importance_model_path)
                features_list = importance_features.get_features_list(True)
                sent_reps = importance_features.load_data(os.path.join(save_path, dataset_split + '*'), -1)
                print('Loaded %d sentence representations' % len(sent_reps))
                x_y = importance_features.features_to_array(sent_reps, features_list)
                train_x, train_y = x_y[:, :-1], x_y[:, -1]
                svr_model = importance_features.run_training(train_x, train_y)
                with open(importance_model_path, 'wb') as f:
                    pickle.dump(svr_model, f)

    # Create a batcher object that will create minibatches of data
    batcher = Batcher(FLAGS.data_path, vocab, hps, single_pass=FLAGS.single_pass)

    tf.set_random_seed(113)  # a seed value for randomness

    # Run training, eval, or decoding depending on the mode flag
    if hps.mode == 'train':
        print("creating model...")
        model = SummarizationModel(hps, vocab)
        setup_training(model, batcher)
    elif hps.mode == 'eval':
        model = SummarizationModel(hps, vocab)
        run_eval(model, batcher, vocab)
    elif hps.mode == 'decode':
        # The model is configured with max_dec_steps=1 because we only ever run one step of the
        # decoder at a time (to do beam search). Note that the batcher is initialized with
        # max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries.
        decode_model_hps = hps._replace(max_dec_steps=1)
        model = SummarizationModel(decode_model_hps, vocab)
        decoder = BeamSearchDecoder(model, batcher, vocab)
        decoder.decode()  # decode indefinitely (unless single_pass=True, in which case decode the dataset exactly once)
        # while True:
        #     a = 0
    else:
        raise ValueError("The 'mode' flag must be one of train/eval/decode")
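
# Hedged sketch (the helper name and its placement here are hypothetical, not from the
# original file): reloading the importance-model artifacts that the pg_mmr branch above
# writes out, i.e. the dill-serialized TFIDF vectorizer and the pickled SVR model.
import os
import pickle

import dill


def _load_importance_models(actual_log_root, original_dataset_name):
    tfidf_path = os.path.join(actual_log_root, 'tfidf_vectorizer', original_dataset_name + '.dill')
    svr_path = os.path.join(actual_log_root, 'svr.pickle')
    tfidf_vectorizer, svr_model = None, None
    if os.path.exists(tfidf_path):
        with open(tfidf_path, 'rb') as f:
            tfidf_vectorizer = dill.load(f)
    if os.path.exists(svr_path):
        with open(svr_path, 'rb') as f:
            svr_model = pickle.load(f)
    return tfidf_vectorizer, svr_model
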
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    if FLAGS.singles_and_pairs == 'both':
        FLAGS.exp_name = FLAGS.exp_name + '_both'
        exp_name = _exp_name + '_both'
        dataset_articles = _dataset_articles
    else:
        FLAGS.exp_name = FLAGS.exp_name + '_singles'
        exp_name = _exp_name + '_singles'
        dataset_articles = _dataset_articles + '_singles'
    my_log_dir = os.path.join(log_dir, FLAGS.ssi_exp_name)

    print('Running statistics on %s' % FLAGS.exp_name)

    if FLAGS.dataset_name != "":
        FLAGS.data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name, FLAGS.dataset_split + '*')
    if not os.path.exists(os.path.join(FLAGS.data_root, FLAGS.dataset_name)) or len(
            os.listdir(os.path.join(FLAGS.data_root, FLAGS.dataset_name))) == 0:
        print('No TF example data found at %s so creating it from raw data.' % os.path.join(FLAGS.data_root, FLAGS.dataset_name))
        convert_data.process_dataset(FLAGS.dataset_name)

    logging.set_verbosity(logging.INFO)  # choose what level of logging you want
    logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.exp_name = FLAGS.exp_name if FLAGS.exp_name != '' else FLAGS.dataset_name
    FLAGS.actual_log_root = FLAGS.log_root
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)

    vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size)  # create a vocabulary

    # If in decode mode, set batch_size = beam_size
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception("The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
                   'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps',
                   'max_enc_steps', 'coverage', 'cov_loss_wt', 'pointer_gen', 'lambdamart_input']
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val.value  # add it to the dict
    hps = namedtuple("HParams", list(hps_dict.keys()))(**hps_dict)

    tf.set_random_seed(113)  # a seed value for randomness

    # The model is configured with max_dec_steps=1 because we only ever run one step of the
    # decoder at a time (to do beam search). Note that the batcher is initialized with
    # max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries.
    decode_model_hps = hps._replace(max_dec_steps=1)

    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(data_dir, dataset_articles)
    source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))

    with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:
        ssi_list = pickle.load(f)

    total = len(source_files) * 1000 if ('cnn' in dataset_articles or 'newsroom' in dataset_articles) else len(source_files)
    example_generator = data.example_generator(source_dir + '/' + dataset_split + '*', True, False, should_check_valid=False)
    # batcher = Batcher(None, vocab, hps, single_pass=FLAGS.single_pass)
    model = SummarizationModel(decode_model_hps, vocab)
    decoder = BeamSearchDecoder(model, None, vocab)
    decoder.decode_iteratively(example_generator, total, names_to_types, ssi_list, hps)

    a = 0
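
# Illustrative helper (hypothetical; it mirrors the inline expression above): sharded
# datasets such as 'cnn' and 'newsroom' are assumed to pack roughly 1000 examples per
# TF example file, while other datasets have one example file per split, so the
# progress total is estimated from the file count.
def _estimate_total_examples(source_files, dataset_articles, examples_per_shard=1000):
    if 'cnn' in dataset_articles or 'newsroom' in dataset_articles:
        return len(source_files) * examples_per_shard
    return len(source_files)
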
def main(unused_argv):
    if len(unused_argv) != 1:  # prints a message if you've entered flags incorrectly
        raise Exception("Problem with flags: %s" % unused_argv)

    extractor = 'bert' if FLAGS.use_bert else 'lambdamart'
    if FLAGS.cnn_dm_pg:
        pretrained_dataset = 'cnn_dm'
    elif FLAGS.websplit:
        pretrained_dataset = 'websplit'
    else:
        pretrained_dataset = FLAGS.dataset_name
    if FLAGS.dataset_name == 'duc_2004':
        pretrained_dataset = 'cnn_dm'
    if FLAGS.singles_and_pairs == 'both':
        FLAGS.exp_name = FLAGS.dataset_name + '_' + FLAGS.exp_name + extractor + '_both'
        FLAGS.pretrained_path = os.path.join(FLAGS.log_root, pretrained_dataset + '_both')
        dataset_articles = FLAGS.dataset_name
    else:
        FLAGS.exp_name = FLAGS.dataset_name + '_' + FLAGS.exp_name + extractor + '_singles'
        FLAGS.pretrained_path = os.path.join(FLAGS.log_root, pretrained_dataset + '_singles')
        dataset_articles = FLAGS.dataset_name + '_singles'
    if FLAGS.word_imp_reg:
        FLAGS.pretrained_path += '_imp' + str(FLAGS.imp_loss_wt)
        FLAGS.exp_name += '_imp' + str(FLAGS.imp_loss_wt)
        if FLAGS.imp_loss_oneminus:
            FLAGS.pretrained_path += '_oneminus'
            FLAGS.exp_name += '_oneminus'
    if FLAGS.sep:
        FLAGS.pretrained_path += '_sep'
        FLAGS.exp_name += '_sep'
    if FLAGS.tag_tokens:
        FLAGS.pretrained_path += '_tag'
        FLAGS.exp_name += '_tag' + str(FLAGS.tag_loss_wt)

    bert_suffix = ''
    # if FLAGS.use_bert:
    #     if FLAGS.sentemb:
    #         FLAGS.exp_name += '_sentemb'
    #         bert_suffix += '_sentemb'
    #     if FLAGS.artemb:
    #         FLAGS.exp_name += '_artemb'
    #         bert_suffix += '_artemb'
    #     if FLAGS.plushidden:
    #         FLAGS.exp_name += '_plushidden'
    #         bert_suffix += '_plushidden'
    if FLAGS.tag_tokens:
        bert_suffix += '_tag' + str(FLAGS.tag_loss_wt)
    else:
        bert_suffix += '_tag' + '0.0'
    if FLAGS.upper_bound:
        FLAGS.exp_name = FLAGS.exp_name + '_upperbound'
        ssi_list = None  # this is if we are doing the upper bound evaluation (ssi_list comes straight from the groundtruth)
    else:
        my_log_dir = os.path.join(log_dir, '%s_%s_%s%s' % (FLAGS.dataset_name, extractor, FLAGS.singles_and_pairs, bert_suffix))
        print(util.bcolors.OKGREEN + "BERT path: " + my_log_dir + util.bcolors.ENDC)
        with open(os.path.join(my_log_dir, 'ssi.pkl'), 'rb') as f:
            ssi_list = pickle.load(f)
        FLAGS.ssi_data_path = my_log_dir
    if FLAGS.cnn_dm_pg:
        FLAGS.exp_name = FLAGS.exp_name + '_cnntrained'
    if FLAGS.websplit:
        FLAGS.exp_name = FLAGS.exp_name + '_websplittrained'
    if FLAGS.first_intact:
        FLAGS.exp_name = FLAGS.exp_name + '_firstintact'

    print('Running statistics on %s' % FLAGS.exp_name)

    if FLAGS.dataset_name != "":
        FLAGS.data_path = os.path.join(FLAGS.data_root, FLAGS.dataset_name, FLAGS.dataset_split + '*')
    if not os.path.exists(os.path.join(FLAGS.data_root, FLAGS.dataset_name)) or len(
            os.listdir(os.path.join(FLAGS.data_root, FLAGS.dataset_name))) == 0:
        print('No TF example data found at %s so creating it from raw data.' % os.path.join(FLAGS.data_root, FLAGS.dataset_name))
        convert_data.process_dataset(FLAGS.dataset_name)

    logging.set_verbosity(logging.INFO)  # choose what level of logging you want
    logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    # Change log_root to FLAGS.log_root/FLAGS.exp_name and create the dir if necessary
    FLAGS.exp_name = FLAGS.exp_name if FLAGS.exp_name != '' else FLAGS.dataset_name
    FLAGS.actual_log_root = FLAGS.log_root
    FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    print(util.bcolors.OKGREEN + "Experiment path: " + FLAGS.log_root + util.bcolors.ENDC)

    if FLAGS.dataset_name == 'duc_2004':
        vocab = Vocab(FLAGS.vocab_path + '_' + 'cnn_dm', FLAGS.vocab_size, add_sep=FLAGS.sep)  # create a vocabulary
    else:
        vocab_datasets = [os.path.basename(file_path).split('vocab_')[1] for file_path in glob.glob(FLAGS.vocab_path + '_*')]
        original_dataset_name = [file_name for file_name in vocab_datasets if file_name in FLAGS.dataset_name]
        if len(original_dataset_name) > 1:
            raise Exception('Too many choices for vocab file')
        if len(original_dataset_name) < 1:
            raise Exception('No vocab file for dataset created. Run make_vocab.py --dataset_name=<my original dataset name>')
        original_dataset_name = original_dataset_name[0]
        FLAGS.original_dataset_name = original_dataset_name
        vocab = Vocab(FLAGS.vocab_path + '_' + original_dataset_name, FLAGS.vocab_size, add_sep=FLAGS.sep)  # create a vocabulary

    # If in decode mode, set batch_size = beam_size
    # Reason: in decode mode, we decode one example at a time.
    # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses.
    if FLAGS.mode == 'decode':
        FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode
    if FLAGS.single_pass and FLAGS.mode != 'decode':
        raise Exception("The single_pass flag should only be True in decode mode")

    # Make a namedtuple hps, containing the values of the hyperparameters that the model needs
    # hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std',
    #                'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps',
    #                'max_enc_steps', 'coverage', 'cov_loss_wt', 'pointer_gen', 'lambdamart_input', 'pg_mmr',
    #                'singles_and_pairs', 'skip_with_less_than_3', 'ssi_data_path', 'word_imp_reg', 'imp_loss_wt']
    hparam_list = [item for item in list(FLAGS.flag_values_dict().keys()) if item != '?']
    hps_dict = {}
    for key, val in FLAGS.__flags.items():  # for each flag
        if key in hparam_list:  # if it's in the list
            hps_dict[key] = val.value  # add it to the dict
    hps = namedtuple("HParams", list(hps_dict.keys()))(**hps_dict)

    tf.set_random_seed(113)  # a seed value for randomness

    # The model is configured with max_dec_steps=1 because we only ever run one step of the
    # decoder at a time (to do beam search). Note that the batcher is initialized with
    # max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries.
    decode_model_hps = hps._replace(max_dec_steps=1)

    start_time = time.time()
    np.random.seed(random_seed)
    source_dir = os.path.join(FLAGS.data_root, dataset_articles)
    source_files = sorted(glob.glob(source_dir + '/' + dataset_split + '*'))
    total = len(source_files) * 1000 if ('cnn' in dataset_articles or 'xsum' in dataset_articles) else len(source_files)
    example_generator = data.example_generator(source_dir + '/' + dataset_split + '*', True, False, should_check_valid=False)
    # batcher = Batcher(None, vocab, hps, single_pass=FLAGS.single_pass)
    model = SummarizationModel(decode_model_hps, vocab)
    decoder = BeamSearchDecoder(model, None, vocab)
    decoder.decode_iteratively(example_generator, total, names_to_types, ssi_list, hps)

    # num_outside = []
    # for example_idx, example in enumerate(tqdm(example_generator, total=total)):
    #     raw_article_sents, groundtruth_similar_source_indices_list, groundtruth_summary_text, corefs = util.unpack_tf_example(
    #         example, names_to_types)
    #     article_sent_tokens = [util.process_sent(sent) for sent in raw_article_sents]
    #     cur_token_idx = 0
    #     for sent_idx, sent_tokens in enumerate(article_sent_tokens):
    #         for token in sent_tokens:
    #             cur_token_idx += 1
    #             if cur_token_idx >= 400:
    #                 sent_idx_at_400 = sent_idx
    #                 break
    #         if cur_token_idx >= 400:
    #             break
    #
    #     my_num_outside = 0
    #     for ssi in groundtruth_similar_source_indices_list:
    #         for source_idx in ssi:
    #             if source_idx >= sent_idx_at_400:
    #                 my_num_outside += 1
    #     num_outside.append(my_num_outside)
    # print "num_outside = %d" % np.mean(num_outside)

    a = 0
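
# Self-contained toy (the helper name is hypothetical) mirroring the vocab-file selection
# logic in main() above: among files named vocab_<dataset>, pick the unique dataset suffix
# that is a substring of the requested dataset_name.
def _pick_vocab_dataset(vocab_file_names, dataset_name):
    candidates = [name.split('vocab_')[1] for name in vocab_file_names if 'vocab_' in name]
    matches = [c for c in candidates if c in dataset_name]
    if len(matches) > 1:
        raise Exception('Too many choices for vocab file')
    if len(matches) < 1:
        raise Exception('No vocab file for dataset created.')
    return matches[0]

# Example: _pick_vocab_dataset(['vocab_cnn_dm', 'vocab_xsum'], 'cnn_dm_singles') returns 'cnn_dm'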