def main():
    """Load the corpus, derive model FLAGS from its shapes, then train or evaluate.

    Reads module-level ``datapath``, ``FLAGS``, ``args`` and ``dataset``;
    mutates ``FLAGS`` in place before constructing the model.
    """
    data, vocab, htree = get_data(datapath, FLAGS.test_size, FLAGS.val_size)
    train_x, train_y, val_x, val_y, test_x, test_y = (data[i] for i in range(6))
    idx2word, word2idx = vocab[0], vocab[1]

    # Derive model hyper-parameters from the loaded data.
    FLAGS.vocab_len = len(idx2word)
    # Huffman paths are padded to a common length, so any key yields the
    # maximum path length; '<pad>' is guaranteed to be present.
    FLAGS.max_depth = len(htree['<pad>'])
    FLAGS.num_samples = train_x.shape[0]
    FLAGS.sentence_len = train_x.shape[-1]
    FLAGS.timesteps = train_y.shape[-1]
    FLAGS.embedding_method = args.load_embed
    FLAGS.embedding_size = int(args.dim_embed)
    FLAGS.dataset_name = dataset
    if args.attlayer == "concat":
        FLAGS.multi_concat = False

    print('news headlines format:', train_y.shape)
    print('news descriptions format:', train_x.shape)
    print('number of tokens in the vocabulary', FLAGS.vocab_len)
    print("huffman tree max depth ", max(len(p) for p in htree.values()))

    model = SummarizationModel(FLAGS, idx2word, htree)
    if args.train is not None:
        # Training mode
        model.train(train_x, train_y, val_x, val_y)
    else:
        # Evaluation mode
        model.eval(test_x, test_y)
def main(args):
    """Entry point: seed RNGs, patch FLAGS for the selected mode, build the
    vocab/data/model, then train, test or evaluate.

    Args:
        args: leftover positional argv entries from the flags parser; any
            entry beyond the program name means an unparsed flag.

    Raises:
        Exception: if unparsed command-line arguments remain.
        ValueError: if ``FLAGS.model`` or ``FLAGS.mode`` is unsupported.
    """
    main_start = time.time()

    # Fix all RNG seeds for reproducibility across TF, stdlib and NumPy.
    tf.set_random_seed(2019)
    random.seed(2019)
    np.random.seed(2019)

    if len(args) != 1:
        raise Exception('Problem with flags: %s' % args)

    # Correcting a few flags for test/eval mode: beam search decodes one
    # beam per batch slot, one token step at a time (except the "tx" model).
    if FLAGS.mode != 'train':
        FLAGS.batch_size = FLAGS.beam_size
        FLAGS.bs_dec_steps = FLAGS.dec_steps
        if FLAGS.model.lower() != "tx":
            FLAGS.dec_steps = 1

    assert FLAGS.mode == 'train' or FLAGS.batch_size == FLAGS.beam_size, \
        "In test mode, batch size should be equal to beam size."
    assert FLAGS.mode == 'train' or FLAGS.dec_steps == 1 or FLAGS.model.lower() == "tx", \
        "In test mode, no. of decoder steps should be one."

    os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
    os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(
        str(gpu_id) for gpu_id in FLAGS.GPUs)

    if not os.path.exists(FLAGS.PathToCheckpoint):
        os.makedirs(FLAGS.PathToCheckpoint)
    # NOTE(review): assumes FLAGS.PathToResults ends with a path separator —
    # confirm; otherwise 'predictions'/'groundtruths' become siblings of the
    # results dir rather than children.
    if FLAGS.mode == "test" and not os.path.exists(FLAGS.PathToResults):
        os.makedirs(FLAGS.PathToResults)
        os.makedirs(FLAGS.PathToResults + 'predictions')
        os.makedirs(FLAGS.PathToResults + 'groundtruths')

    if FLAGS.mode == 'eval':
        eval_model(FLAGS.PathToResults)
    else:
        start = time.time()
        vocab = Vocab(max_vocab_size=FLAGS.vocab_size,
                      emb_dim=FLAGS.dim,
                      dataset_path=FLAGS.PathToDataset,
                      glove_path=FLAGS.PathToGlove,
                      vocab_path=FLAGS.PathToVocab,
                      lookup_path=FLAGS.PathToLookups)

        model_name = FLAGS.model.lower()  # hoisted: reused in every branch
        if model_name == "plain":
            print("Setting up the plain model.\n")
            data = DataGenerator(path_to_dataset=FLAGS.PathToDataset,
                                 max_inp_seq_len=FLAGS.enc_steps,
                                 max_out_seq_len=FLAGS.dec_steps,
                                 vocab=vocab,
                                 use_pgen=FLAGS.use_pgen,
                                 use_sample=FLAGS.sample)
            summarizer = SummarizationModel(vocab, data)
        elif model_name in ("hier", "rlhier"):
            # Both hierarchical variants share the same data pipeline; only
            # the model class differs.
            if model_name == "hier":
                print("Setting up the hier model.\n")
            else:
                print("Setting up the Hier RL model.\n")
            data = DataGeneratorHier(
                path_to_dataset=FLAGS.PathToDataset,
                max_inp_sent=FLAGS.max_enc_sent,
                max_inp_tok_per_sent=FLAGS.max_enc_steps_per_sent,
                max_out_tok=FLAGS.dec_steps,
                vocab=vocab,
                use_pgen=FLAGS.use_pgen,
                use_sample=FLAGS.sample)
            if model_name == "hier":
                summarizer = SummarizationModelHier(vocab, data)
            else:
                summarizer = SummarizationModelHierSC(vocab, data)
        else:
            # BUG FIX: the old message advertised "bayesian/shared", which
            # no branch above handles; list the values actually supported.
            raise ValueError(
                "model flag should be either of plain/hier/rlhier!! \n")

        end = time.time()
        print("Setting up vocab, data and model took {:.2f} sec.".format(
            end - start))

        summarizer.build_graph()

        if FLAGS.mode == 'train':
            summarizer.train()
        elif FLAGS.mode == "test":
            summarizer.test()
        else:
            raise ValueError("mode should be either train/test!! \n")

    main_end = time.time()
    print("Total time elapsed: %.2f \n" % (main_end - main_start))