def decode(): '''Load dictionaries''' # Load vocabularies. print(os.getcwd()) doc_dict = data_util.load_dict(FLAGS.data_dir + "/doc_dict.txt") sum_dict = data_util.load_dict(FLAGS.data_dir + "/sum_dict.txt") if doc_dict is None or sum_dict is None: logging.warning("Dict not found.") print("Loading testing data") data = data_util.load_test_data(FLAGS.test_file, doc_dict) with tf.Session() as sess: # Create model and load parameters. logging.info("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size)) model = create_model(sess, True) result = [] for idx, token_ids in enumerate(data): # Get a 1-element batch to feed the sentence to the model. encoder_inputs, decoder_inputs, encoder_len, decoder_len =\ model.get_batch( {0: [(token_ids, [data_util.ID_GO, data_util.ID_EOS])]}, 0) if FLAGS.batch_size == 1 and FLAGS.geneos: print('ran code') loss, outputs = model.step(sess, encoder_inputs, decoder_inputs, encoder_len, decoder_len, True) outputs = [np.argmax(item) for item in outputs[0]] else: outputs = model.step_beam(sess, encoder_inputs, encoder_len, geneos=FLAGS.geneos) # If there is an EOS symbol in outputs, cut them at that point. if data_util.ID_EOS in outputs: outputs = outputs[:outputs.index(data_util.ID_EOS)] gen_sum = " ".join(data_util.sen_map2tok(outputs, sum_dict[1])) print(gen_sum) gen_sum = data_util.sen_postprocess(gen_sum) print(gen_sum) result.append(gen_sum) logging.info("Finish {} samples. :: {}".format(idx, gen_sum[:75])) with open(FLAGS.test_output, "w") as f: for item in result: print(item, file=f)
def decode(): # Load vocabularies. doc_dict = data_util.load_dict(FLAGS.data_dir + "/doc_dict.txt") sum_dict = data_util.load_dict(FLAGS.data_dir + "/sum_dict.txt") en_dict = data_util.load_dict(FLAGS.data_dir + "/en_dict.txt") if doc_dict is None or sum_dict is None: logging.warning("Dict not found.") data, en_data = data_util.load_test_data( FLAGS.test_file, doc_dict, FLAGS.data_dir + "/test.entity.txt", en_dict) with tf.Session() as sess: # Create model and load parameters. logging.info("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size)) model = create_model(sess, True, None, None, None) result = [] for idx, token_ids in enumerate(data): en_ids = en_data[idx] if len(en_ids) == 0: en_ids = [data_util.ID_PAD] # token_ids, en_ids = d #print(idx) #print(token_ids) # Get a 1-element batch to feed the sentence to the model. shiva = model.get_batch( { 0: [(token_ids, [data_util.ID_GO, data_util.ID_EOS], [data_util.ID_PAD, data_util.ID_PAD, data_util.ID_PAD] + en_ids + [data_util.ID_PAD, data_util.ID_PAD, data_util.ID_PAD])] }, 0) #print(shiva) encoder_inputs, decoder_inputs, encoder_len, decoder_len, entity_inputs, entity_len = shiva K = min(FLAGS.K, np.amax(entity_len) - 6) #print("K", K) if FLAGS.batch_size == 1 and FLAGS.geneos: loss, outputs, att, t = model.step(sess, encoder_inputs, decoder_inputs, entity_inputs, encoder_len, decoder_len, entity_len, K, True) #outputs = [np.argmax(item) for item in outputs[0]] else: outputs = model.step_beam(sess, encoder_inputs, encoder_len, entity_inputs, entity_len, K, geneos=FLAGS.geneos) # If there is an EOS symbol in outputs, cut them at that point. #print(outputs) f2 = open(FLAGS.test_output + '.disambig', 'a') f2.write(' '.join( str(y) + ":" + str(x.mean()) for x, y in zip(t[0], entity_inputs[0][3:])) + '\n') f2.close() f2 = open(FLAGS.test_output + '.attention', 'a') f2.write(' '.join( str(y) + ":" + str(x) for x, y in zip(att[0], entity_inputs[0][3:])) + '\n') f2.close() outputs = list(outputs[0]) if data_util.ID_EOS in output: outputs = outputs[:outputs.index(data_util.ID_EOS)] #outputs = list(outputs) gen_sum = " ".join(data_util.sen_map2tok( outputs, sum_dict[1])) #sum_dict[1])) #lvt_str gen_sum = data_util.sen_postprocess(gen_sum) result.append(gen_sum) logging.info("Finish {} samples. :: {}".format(idx, gen_sum[:75])) with open(FLAGS.test_output, "w") as f: for item in result: print(item, file=f)
def decode(): # Load vocabularies. doc_dict = data_util.load_dict(FLAGS.data_dir + "/doc_dict.txt") sum_dict = data_util.load_dict(FLAGS.data_dir + "/sum_dict.txt") if doc_dict is None or sum_dict is None: logging.warning("Dict not found.") docs, data = data_util.load_test_data(FLAGS.test_file, doc_dict) with tf.Session() as sess: # Create model and load parameters. logging.info("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size)) model = create_model(sess, True) class_model = create_class_model(sess, True) result = [] for idx, token_ids in enumerate(data): # Get a 1-element batch to feed the sentence to the model. encoder_inputs, decoder_inputs, encoder_len, decoder_len, class_output, class_len =\ data_util.get_batch( {0: [(token_ids, [data_util.ID_GO, data_util.ID_EOS],[0,0])]}, _buckets, 0, FLAGS.batch_size, False, 0) if FLAGS.batch_size == 1 and FLAGS.geneos: loss, outputs = model.step(sess, encoder_inputs, decoder_inputs, encoder_len, decoder_len, True) outputs = [np.argmax(item) for item in outputs[0]] else: outputs = model.step_beam(sess, encoder_inputs, encoder_len, geneos=FLAGS.geneos) # If there is an EOS symbol in outputs, cut them at that point. if data_util.ID_EOS in outputs: outputs = outputs[:outputs.index(data_util.ID_EOS)] gen_sum = " ".join(data_util.sen_map2tok(outputs, sum_dict[1])) gen_sum = data_util.sen_postprocess(gen_sum) result.append(gen_sum) logging.info("Finish {} samples. :: {}".format(idx, gen_sum[:75])) #Get Encoder outputs batchidx = 0 final_inputs = [] final_outputs = [] final_len = [] while batchidx + FLAGS.batch_size <= len(data): encoder_inputs, decoder_inputs, encoder_len, decoder_len, class_output, class_len =\ data_util.get_batch( {0: [(token_ids, [data_util.ID_GO, data_util.ID_EOS],[0,0])]}, _buckets, 0, FLAGS.batch_size, False, 0) _, _, enc_outputs = model.step(sess, encoder_inputs, decoder_inputs, encoder_len, decoder_len, True) enc_outputs = data_util.add_pad_for_hidden(enc_outputs, _buckets[0][0]) final_inputs.append(enc_outputs) final_outputs.append(class_output) final_len.append(class_len) batchidx += FLAGS.batch_size final_inputs = np.asarray(final_inputs) final_inputs = np.concatenate(final_inputs, 0) final_outputs = np.asarray(final_outputs) final_outputs = np.concatenate(final_outputs, 0) final_len = np.asarray(final_len) final_len = np.concatenate(final_len, 0) print(final_inputs.shape, final_outputs.shape, final_len.shape) #Hidden classifier step_loss, output = class_model.step(sess, final_inputs[:], final_outputs[:], final_len[:], True) clipped = np.array(output > 0.5, dtype=np.int) #label = data_util.hidden_label_gen(FLAGS.test_file, "data/test.1981.msg.txt") #make confusion matrix to get precision #tn, fp, fn, tp = confusion_matrix(label.flatten(), clipped.flatten()).ravel() #print("Test precision : ", tp/(tp+fp)) with open(FLAGS.test_output, "w") as f: for idx, item in enumerate(result): print(item, file=f) for j in range(len(docs[idx])): if clipped[idx][j] == 1: print("Recommended identifier: " + docs[idx][j] + " ", file=f) print("\n", file=f)