Example #1
def main():
    data, vocab, htree = get_data(datapath, FLAGS.test_size, FLAGS.val_size)

    train_x, train_y = data[0], data[1]
    val_x, val_y = data[2], data[3]
    test_x, test_y = data[4], data[5]

    idx2word, word2idx = vocab[0], vocab[1]
    # train_x, train_y, val_x, val_y, test_x, test_y, idx2word, word2idx = \
    #     get_data(datapath, FLAGS.test_size, FLAGS.val_size)

    FLAGS.vocab_len = len(idx2word)

    # any key gives the correct max path length
    FLAGS.max_depth = len(htree['<pad>'])
    FLAGS.num_samples = train_x.shape[0]
    FLAGS.sentence_len = train_x.shape[-1]
    FLAGS.timesteps = train_y.shape[-1]
    FLAGS.embedding_method = args.load_embed
    FLAGS.embedding_size = int(args.dim_embed)
    FLAGS.dataset_name = dataset

    # if dataset != "bbc_news":
    #     FLAGS.batch_size = 1

    if args.attlayer == "concat":
        FLAGS.multi_concat = False

    print('news headlines shape:', train_y.shape)
    print('news descriptions shape:', train_x.shape)
    print('number of tokens in the vocabulary:', FLAGS.vocab_len)
    print('huffman tree max depth:', max(len(path) for path in htree.values()))

    model = SummarizationModel(FLAGS, idx2word, htree)

    if args.train is not None:
        # Train the model.
        model.train(train_x, train_y, val_x, val_y)
    else:
        # Evaluate the model.
        model.eval(test_x, test_y)
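
This excerpt relies on module-level objects (FLAGS, args, datapath, dataset, get_data, SummarizationModel) that are defined elsewhere in the original file. A minimal, hypothetical sketch of that surrounding setup, with placeholder names and values rather than the original project's configuration, might look like this:

# Hypothetical module-level scaffolding assumed by the example above; the
# attribute and argument names mirror what the example reads and writes, but
# the defaults and paths are placeholders, not the original project's values.
import argparse


class Config(object):
    """Stand-in for the mutable FLAGS object the example populates."""
    test_size = 0.1
    val_size = 0.1
    multi_concat = True


FLAGS = Config()

parser = argparse.ArgumentParser()
parser.add_argument('--load_embed', default='glove')  # -> FLAGS.embedding_method
parser.add_argument('--dim_embed', default='300')     # -> FLAGS.embedding_size
parser.add_argument('--attlayer', default='dot')      # "concat" disables multi_concat
parser.add_argument('--train', default=None)          # any value switches to training
args = parser.parse_args()

dataset = 'bbc_news'            # name recorded in FLAGS.dataset_name
datapath = './data/' + dataset  # location handed to get_data()

if __name__ == '__main__':
    main()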
Example #2
def main(args):
    main_start = time.time()

    tf.set_random_seed(2019)
    random.seed(2019)
    np.random.seed(2019)

    if len(args) != 1:
        raise Exception('Problem with flags: %s' % args)

    # Correcting a few flags for test/eval mode.
    if FLAGS.mode != 'train':
        FLAGS.batch_size = FLAGS.beam_size
        FLAGS.bs_dec_steps = FLAGS.dec_steps

        if FLAGS.model.lower() != "tx":
            FLAGS.dec_steps = 1

    assert FLAGS.mode == 'train' or FLAGS.batch_size == FLAGS.beam_size, \
        "In test mode, batch size should be equal to beam size."

    assert FLAGS.mode == 'train' or FLAGS.dec_steps == 1 or FLAGS.model.lower() == "tx", \
        "In test mode, no. of decoder steps should be one."

    os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
    os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(
        str(gpu_id) for gpu_id in FLAGS.GPUs)

    if not os.path.exists(FLAGS.PathToCheckpoint):
        os.makedirs(FLAGS.PathToCheckpoint)

    if FLAGS.mode == "test" and not os.path.exists(FLAGS.PathToResults):
        os.makedirs(FLAGS.PathToResults)
        os.makedirs(FLAGS.PathToResults + 'predictions')
        os.makedirs(FLAGS.PathToResults + 'groundtruths')

    if FLAGS.mode == 'eval':
        eval_model(FLAGS.PathToResults)
    else:
        start = time.time()
        vocab = Vocab(max_vocab_size=FLAGS.vocab_size,
                      emb_dim=FLAGS.dim,
                      dataset_path=FLAGS.PathToDataset,
                      glove_path=FLAGS.PathToGlove,
                      vocab_path=FLAGS.PathToVocab,
                      lookup_path=FLAGS.PathToLookups)

        if FLAGS.model.lower() == "plain":
            print("Setting up the plain model.\n")
            data = DataGenerator(path_to_dataset=FLAGS.PathToDataset,
                                 max_inp_seq_len=FLAGS.enc_steps,
                                 max_out_seq_len=FLAGS.dec_steps,
                                 vocab=vocab,
                                 use_pgen=FLAGS.use_pgen,
                                 use_sample=FLAGS.sample)
            summarizer = SummarizationModel(vocab, data)

        elif FLAGS.model.lower() == "hier":
            print("Setting up the hier model.\n")
            data = DataGeneratorHier(
                path_to_dataset=FLAGS.PathToDataset,
                max_inp_sent=FLAGS.max_enc_sent,
                max_inp_tok_per_sent=FLAGS.max_enc_steps_per_sent,
                max_out_tok=FLAGS.dec_steps,
                vocab=vocab,
                use_pgen=FLAGS.use_pgen,
                use_sample=FLAGS.sample)
            summarizer = SummarizationModelHier(vocab, data)

        elif FLAGS.model.lower() == "rlhier":
            print("Setting up the Hier RL model.\n")
            data = DataGeneratorHier(
                path_to_dataset=FLAGS.PathToDataset,
                max_inp_sent=FLAGS.max_enc_sent,
                max_inp_tok_per_sent=FLAGS.max_enc_steps_per_sent,
                max_out_tok=FLAGS.dec_steps,
                vocab=vocab,
                use_pgen=FLAGS.use_pgen,
                use_sample=FLAGS.sample)
            summarizer = SummarizationModelHierSC(vocab, data)

        else:
            raise ValueError(
                "model flag should be either of plain/hier/bayesian/shared!! \n"
            )

        end = time.time()
        print("Setting up vocab, data and model took {:.2f} sec.".format(end - start))

        summarizer.build_graph()

        if FLAGS.mode == 'train':
            summarizer.train()
        elif FLAGS.mode == "test":
            summarizer.test()
        else:
            raise ValueError("mode should be either train/test!! \n")

        main_end = time.time()
        print("Total time elapsed: %.2f \n" % (main_end - main_start))