Example #1
def decode():
    """Decode the test file and write the generated summaries to FLAGS.test_output."""
    # Print the working directory to help debug relative data paths.
    print(os.getcwd())
    # Load vocabularies.
    doc_dict = data_util.load_dict(FLAGS.data_dir + "/doc_dict.txt")
    sum_dict = data_util.load_dict(FLAGS.data_dir + "/sum_dict.txt")
    if doc_dict is None or sum_dict is None:
        logging.warning("Dict not found.")
    print("Loading testing data")
    data = data_util.load_test_data(FLAGS.test_file, doc_dict)

    with tf.Session() as sess:
        # Create model and load parameters.
        logging.info("Creating %d layers of %d units." %
                     (FLAGS.num_layers, FLAGS.size))
        model = create_model(sess, True)

        result = []
        for idx, token_ids in enumerate(data):

            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, encoder_len, decoder_len =\
                model.get_batch(
                    {0: [(token_ids, [data_util.ID_GO, data_util.ID_EOS])]}, 0)

            if FLAGS.batch_size == 1 and FLAGS.geneos:
                # Greedy decoding: take the argmax at each output position.
                loss, outputs = model.step(sess, encoder_inputs,
                                           decoder_inputs, encoder_len,
                                           decoder_len, True)
                outputs = [np.argmax(item) for item in outputs[0]]
            else:
                # Beam-search decoding.
                outputs = model.step_beam(sess,
                                          encoder_inputs,
                                          encoder_len,
                                          geneos=FLAGS.geneos)

            # If there is an EOS symbol in outputs, cut them at that point.
            if data_util.ID_EOS in outputs:
                outputs = outputs[:outputs.index(data_util.ID_EOS)]
            gen_sum = " ".join(data_util.sen_map2tok(outputs, sum_dict[1]))
            print(gen_sum)
            gen_sum = data_util.sen_postprocess(gen_sum)
            print(gen_sum)
            result.append(gen_sum)
            logging.info("Finish {} samples. :: {}".format(idx, gen_sum[:75]))
        with open(FLAGS.test_output, "w") as f:
            for item in result:
                print(item, file=f)
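The EOS cut and id-to-token mapping in Example #1 rely on helpers from data_util that are not shown here. A minimal, self-contained sketch of that post-processing step, using made-up id values and a made-up reverse vocabulary rather than the real data_util constants, could look like this:

# Stand-alone sketch of the EOS cut and id-to-token step.
# ID_EOS and rev_sum_vocab are illustrative stand-ins, not data_util values.
ID_EOS = 3
rev_sum_vocab = {3: "<eos>", 10: "fix", 11: "memory", 12: "leak"}

outputs = [10, 11, 12, 3, 0, 0]
if ID_EOS in outputs:
    outputs = outputs[:outputs.index(ID_EOS)]
gen_sum = " ".join(rev_sum_vocab[i] for i in outputs)
print(gen_sum)  # -> fix memory leak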
Example #2
def decode():
    # Load vocabularies.
    doc_dict = data_util.load_dict(FLAGS.data_dir + "/doc_dict.txt")
    sum_dict = data_util.load_dict(FLAGS.data_dir + "/sum_dict.txt")
    en_dict = data_util.load_dict(FLAGS.data_dir + "/en_dict.txt")
    if doc_dict is None or sum_dict is None:
        logging.warning("Dict not found.")
    data, en_data = data_util.load_test_data(
        FLAGS.test_file, doc_dict, FLAGS.data_dir + "/test.entity.txt",
        en_dict)

    with tf.Session() as sess:
        # Create model and load parameters.
        logging.info("Creating %d layers of %d units." %
                     (FLAGS.num_layers, FLAGS.size))
        model = create_model(sess, True, None, None, None)

        result = []
        for idx, token_ids in enumerate(data):
            # Entity ids for this sample; fall back to a single PAD id if empty.
            en_ids = en_data[idx]
            if len(en_ids) == 0:
                en_ids = [data_util.ID_PAD]

            # Get a 1-element batch to feed the sentence to the model.
            # The entity sequence is padded with three PAD ids on each side.
            batch = model.get_batch(
                {
                    0:
                    [(token_ids, [data_util.ID_GO, data_util.ID_EOS],
                      [data_util.ID_PAD, data_util.ID_PAD, data_util.ID_PAD] +
                      en_ids +
                      [data_util.ID_PAD, data_util.ID_PAD, data_util.ID_PAD])]
                }, 0)
            encoder_inputs, decoder_inputs, encoder_len, decoder_len, \
                entity_inputs, entity_len = batch
            # Number of candidate entities, excluding the six PAD ids.
            K = min(FLAGS.K, np.amax(entity_len) - 6)

            if FLAGS.batch_size == 1 and FLAGS.geneos:
                # Greedy decoding: also returns the entity attention weights
                # (att) and disambiguation scores (t).
                loss, outputs, att, t = model.step(sess, encoder_inputs,
                                                   decoder_inputs,
                                                   entity_inputs, encoder_len,
                                                   decoder_len, entity_len, K,
                                                   True)

                # Dump per-entity disambiguation scores and attention weights,
                # skipping the three leading PAD ids.
                with open(FLAGS.test_output + '.disambig', 'a') as f2:
                    f2.write(' '.join(
                        str(y) + ":" + str(x.mean())
                        for x, y in zip(t[0], entity_inputs[0][3:])) + '\n')
                with open(FLAGS.test_output + '.attention', 'a') as f2:
                    f2.write(' '.join(
                        str(y) + ":" + str(x)
                        for x, y in zip(att[0], entity_inputs[0][3:])) + '\n')
            else:
                # Beam-search decoding (no attention weights are returned).
                outputs = model.step_beam(sess,
                                          encoder_inputs,
                                          encoder_len,
                                          entity_inputs,
                                          entity_len,
                                          K,
                                          geneos=FLAGS.geneos)

            # If there is an EOS symbol in outputs, cut them at that point.
            outputs = list(outputs[0])
            if data_util.ID_EOS in outputs:
                outputs = outputs[:outputs.index(data_util.ID_EOS)]
            #outputs = list(outputs)
            gen_sum = " ".join(data_util.sen_map2tok(
                outputs, sum_dict[1]))  #sum_dict[1])) #lvt_str
            gen_sum = data_util.sen_postprocess(gen_sum)
            result.append(gen_sum)
            logging.info("Finish {} samples. :: {}".format(idx, gen_sum[:75]))

        with open(FLAGS.test_output, "w") as f:
            for item in result:
                print(item, file=f)
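The .disambig and .attention side files written in Example #2 pair each entity id (skipping the three leading PAD ids) with a score. A small stand-alone sketch of that line format, using invented attention weights and entity ids, might be:

# Sketch of the "entity_id:score" line format used for the side files.
# att_row and entity_row are invented examples, not real model outputs.
att_row = [0.7, 0.2, 0.1]           # one attention weight per entity
entity_row = [0, 0, 0, 42, 57, 99]  # three leading PAD ids, then entity ids

line = ' '.join(str(y) + ":" + str(x)
                for x, y in zip(att_row, entity_row[3:]))
print(line)  # -> 42:0.7 57:0.2 99:0.1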
Example #3
def decode():
    # Load vocabularies.
    doc_dict = data_util.load_dict(FLAGS.data_dir + "/doc_dict.txt")
    sum_dict = data_util.load_dict(FLAGS.data_dir + "/sum_dict.txt")
    if doc_dict is None or sum_dict is None:
        logging.warning("Dict not found.")
    docs, data = data_util.load_test_data(FLAGS.test_file, doc_dict)

    with tf.Session() as sess:
        # Create model and load parameters.
        logging.info("Creating %d layers of %d units." %
                     (FLAGS.num_layers, FLAGS.size))
        model = create_model(sess, True)
        class_model = create_class_model(sess, True)

        result = []
        for idx, token_ids in enumerate(data):

            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, encoder_len, decoder_len, class_output, class_len =\
                data_util.get_batch(
                    {0: [(token_ids, [data_util.ID_GO, data_util.ID_EOS], [0, 0])]},
                    _buckets, 0, FLAGS.batch_size, False, 0)

            if FLAGS.batch_size == 1 and FLAGS.geneos:
                # Greedy decoding: take the argmax at each output position.
                loss, outputs = model.step(sess, encoder_inputs,
                                           decoder_inputs, encoder_len,
                                           decoder_len, True)
                outputs = [np.argmax(item) for item in outputs[0]]
            else:
                # Beam-search decoding.
                outputs = model.step_beam(sess,
                                          encoder_inputs,
                                          encoder_len,
                                          geneos=FLAGS.geneos)

            # If there is an EOS symbol in outputs, cut them at that point.
            if data_util.ID_EOS in outputs:
                outputs = outputs[:outputs.index(data_util.ID_EOS)]
            gen_sum = " ".join(data_util.sen_map2tok(outputs, sum_dict[1]))
            gen_sum = data_util.sen_postprocess(gen_sum)
            result.append(gen_sum)
            logging.info("Finish {} samples. :: {}".format(idx, gen_sum[:75]))

        # Get encoder outputs for the hidden classifier, one batch at a time.
        batchidx = 0
        final_inputs = []
        final_outputs = []
        final_len = []
        while batchidx + FLAGS.batch_size <= len(data):
            # Build a batch from the next FLAGS.batch_size test sentences.
            encoder_inputs, decoder_inputs, encoder_len, decoder_len, class_output, class_len =\
                data_util.get_batch(
                    {0: [(ids, [data_util.ID_GO, data_util.ID_EOS], [0, 0])
                         for ids in data[batchidx:batchidx + FLAGS.batch_size]]},
                    _buckets, 0, FLAGS.batch_size, False, 0)

            _, _, enc_outputs = model.step(sess, encoder_inputs,
                                           decoder_inputs, encoder_len,
                                           decoder_len, True)

            enc_outputs = data_util.add_pad_for_hidden(enc_outputs,
                                                       _buckets[0][0])

            final_inputs.append(enc_outputs)
            final_outputs.append(class_output)
            final_len.append(class_len)

            batchidx += FLAGS.batch_size

        final_inputs = np.asarray(final_inputs)
        final_inputs = np.concatenate(final_inputs, 0)
        final_outputs = np.asarray(final_outputs)
        final_outputs = np.concatenate(final_outputs, 0)
        final_len = np.asarray(final_len)
        final_len = np.concatenate(final_len, 0)
        print(final_inputs.shape, final_outputs.shape, final_len.shape)

        # Hidden classifier over the padded encoder outputs.
        step_loss, output = class_model.step(sess, final_inputs[:],
                                             final_outputs[:], final_len[:],
                                             True)

        # Threshold the classifier scores at 0.5.
        clipped = np.array(output > 0.5, dtype=int)
        #label = data_util.hidden_label_gen(FLAGS.test_file, "data/test.1981.msg.txt")
        #make confusion matrix to get precision
        #tn, fp, fn, tp = confusion_matrix(label.flatten(), clipped.flatten()).ravel()
        #print("Test precision : ", tp/(tp+fp))

        with open(FLAGS.test_output, "w") as f:
            for idx, item in enumerate(result):
                print(item, file=f)
                for j in range(len(docs[idx])):
                    if clipped[idx][j] == 1:
                        print("Recommended identifier: " + docs[idx][j] + " ",
                              file=f)
                print("\n", file=f)