def output_to_classification_batch(output, batch, batcher, cla_batcher, cc):
    """Convert generator output ids into a classification batch and per-example BLEU scores."""
    example_list = []
    bleu = []
    for i in range(FLAGS.batch_size):
        output_ids = [int(t) for t in output[i]]
        decoded_words = data.outputids2words(output_ids, batcher._vocab, None)

        # Remove everything from the (first) [STOP] token onward, if present.
        try:
            fst_stop_idx = decoded_words.index(data.STOP_DECODING)
            decoded_words = decoded_words[:fst_stop_idx]
        except ValueError:
            pass

        decoded_words_all = ' '.join(decoded_words).strip()  # single string
        decoded_words_all = decoded_words_all.replace("[UNK] ", "")
        decoded_words_all = decoded_words_all.replace("[UNK]", "")

        if decoded_words_all.strip() == "":
            bleu.append(0)
            new_dis_example = bc.Example(".", batch.score, cla_batcher._vocab, cla_batcher._hps)
        else:
            bleu.append(sentence_bleu([batch.original_reviews[i].split()],
                                      decoded_words_all.split(),
                                      smoothing_function=cc.method1))
            new_dis_example = bc.Example(decoded_words_all, batch.score,
                                         cla_batcher._vocab, cla_batcher._hps)
        example_list.append(new_dis_example)

    return bc.Batch(example_list, cla_batcher._hps, cla_batcher._vocab), bleu
def batch_classification_batch(batch, batcher, cla_batcher):
    db_example_list = []
    for i in range(FLAGS.batch_size):
        original_text = batch.original_reviews[i].split()
        if len(original_text) > batcher._hps.max_enc_steps:
            original_text = original_text[:batcher._hps.max_enc_steps]

        # Keep only the tokens whose weight is at least 1.
        new_original_text = []
        for j in range(len(original_text)):
            if batch.weight[i][j] >= 1:
                new_original_text.append(original_text[j])
        new_original_text = " ".join(new_original_text)

        if new_original_text.strip() == "":
            new_original_text = ". "

        new_dis_example = bc.Example(new_original_text, batch.score,
                                     cla_batcher._vocab, cla_batcher._hps)
        db_example_list.append(new_dis_example)

    return bc.Batch(db_example_list, cla_batcher._hps, cla_batcher._vocab)
def run_test_classification(self, data_path, model, batcher, sess):
    """Run the classifier over all full batches in data_path and log overall accuracy."""
    example_list = self.add_example_queue(data_path, batcher._vocab, batcher._hps)
    step = 0
    right = 0
    total = 0
    while step < int(len(example_list) / FLAGS.batch_size):
        current_batch = bc.Batch(
            example_list[step * FLAGS.batch_size:(step + 1) * FLAGS.batch_size],
            batcher._hps, batcher._vocab)
        step += 1
        right_s, number, error_list, error_label = model.run_eval_step(sess, current_batch)
        total += number
        right += right_s
    tf.logging.info("classification acc: " + str(right / (total * 1.0)))
def generator_validation_negative_example(self, path, batcher, model_class,
                                          sess_cls, cla_batcher, mode):
    """Decode transfer batches, dump them to json, and report transfer accuracy and BLEU."""
    # Recreate the output directory so stale decode results are removed.
    if os.path.exists(path):
        shutil.rmtree(path)
    os.mkdir(path)

    counter = 0
    step = 0
    t0 = time.time()

    if mode == 'valid-transfer':
        batches = self.valid_transfer
    elif mode == 'test-transfer':
        batches = self.test_transfer
    print("len test", len(batches))

    list_ref = []
    list_pre = []
    right = 0
    total = 0

    # while step < len(batches):
    while step < 310:
        cla_input = []
        batch = batches[step]
        step += 1
        decode_result = self._model.max_generator(self._sess, batch)
        example_list = []

        for i in range(FLAGS.batch_size):
            original_review = batch.original_reviews[i]  # string
            score = batch.score
            output_ids = [int(t) for t in decode_result['generated'][i]]
            decoded_words = data.outputids2words(output_ids, self._vocab, None)

            # Remove everything from the (first) [STOP] token onward, if present.
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                pass

            decoded_output = ' '.join(decoded_words)  # single string
            self.write_negtive_to_json(original_review, decoded_output, score, counter, path)
            counter += 1  # this is how many examples we've decoded
            cla_input.append(decoded_output)

            if decoded_output.strip() == "":
                decoded_output = ". "
            new_dis_example = bc.Example(decoded_output, batch.score,
                                         cla_batcher._vocab, cla_batcher._hps)
            example_list.append(new_dis_example)

        cla_batch = bc.Batch(example_list, cla_batcher._hps, cla_batcher._vocab)
        right_s, all_s, _, pre = model_class.run_eval_step(sess_cls, cla_batch)
        right += right_s
        total += all_s

        # Collect references/hypotheses for BLEU only when the classifier
        # assigns the transferred sentence its target score.
        for i in range(FLAGS.batch_size):
            if (len(batch.original_reviews[i].split()) > 2
                    and len(cla_input[i].split()) > 2
                    and batch.score == pre[i]):
                list_ref.append([batch.original_reviews[i].split()])
                list_pre.append(cla_input[i].split())

    transfer_acc = right * 1.0 / total * 100
    bleu = corpus_bleu(list_ref, list_pre) * 100
    tf.logging.info("valid transfer acc: " + str(transfer_acc))
    tf.logging.info("BLEU: " + str(bleu))
    return transfer_acc, bleu
def run_test_our_method(cla_batcher, cnn_classifier, sess_cnn, filename):
    """Score transferred test outputs with the CNN classifier and report accuracy and BLEU."""
    test_to_true = read_test_result_our(filename, 1)
    test_to_false = read_test_result_our(filename, 0)
    gold_to_true = read_test_input(filename, 1)
    gold_to_false = read_test_input(filename, 0)

    list_ref = []
    list_pre = []
    right_cnn = 0
    all_cnn = 0

    # Sentences transferred toward label 1.
    for i in range(len(gold_to_true) // 64):
        example_list = []
        for j in range(64):
            new_dis_example = bc.Example(test_to_true[i * 64 + j], 1,
                                         cla_batcher._vocab, cla_batcher._hps)
            example_list.append(new_dis_example)
        cla_batch = bc.Batch(example_list, cla_batcher._hps, cla_batcher._vocab)
        right_s, all_s, _, pre = cnn_classifier.run_eval_step(sess_cnn, cla_batch)
        right_cnn += right_s
        all_cnn += all_s
        # Collect references/hypotheses for BLEU only when the classifier
        # assigns the output its target label.
        for j in range(64):
            if (len(gold_to_true[i * 64 + j].split()) > 2
                    and len(test_to_true[i * 64 + j].split()) > 2
                    and 1 == pre[j]):
                list_ref.append([gold_to_true[i * 64 + j].split()])
                list_pre.append(test_to_true[i * 64 + j].split())

    # Sentences transferred toward label 0.
    for i in range(len(gold_to_false) // 64):
        example_list = []
        for j in range(64):
            new_dis_example = bc.Example(test_to_false[i * 64 + j], 0,
                                         cla_batcher._vocab, cla_batcher._hps)
            example_list.append(new_dis_example)
        cla_batch = bc.Batch(example_list, cla_batcher._hps, cla_batcher._vocab)
        right_s, all_s, _, pre = cnn_classifier.run_eval_step(sess_cnn, cla_batch)
        right_cnn += right_s
        all_cnn += all_s
        for j in range(64):
            if (len(gold_to_false[i * 64 + j].split()) > 2
                    and len(test_to_false[i * 64 + j].split()) > 2
                    and 0 == pre[j]):
                list_ref.append([gold_to_false[i * 64 + j].split()])
                list_pre.append(test_to_false[i * 64 + j].split())

    tf.logging.info("cnn test acc: " + str(right_cnn * 1.0 / all_cnn))
    cc = SmoothingFunction()
    tf.logging.info("BLEU: " + str(corpus_bleu(list_ref, list_pre,
                                               smoothing_function=cc.method1)))