def do_evalb(model_dev, model, sess, dev_set):
    """Decode the dev set batch-by-batch, score each batch with EVALB, and
    print aggregate bracketing precision/recall/F1 for both correction modes.

    Args:
      model_dev: decoding model; its ``batch_size`` is reset per batch.
      model: training model, used for ``global_step`` and checkpoint saving.
      sess: active TensorFlow session.
      dev_set: dict mapping bucket_id -> list of (token_ids, gold_ids) pairs.

    Side effects: rewrites gold/decoded parse files under ``FLAGS.train_dir``
    for every batch, shells out to the external ``evalb`` binary, and saves a
    checkpoint when the bracket-only score clears the threshold below.
    """
    # Load vocabularies.
    eval_batch_size = 50
    sents_vocab_path = os.path.join(FLAGS.data_dir,
                                    "vocab%d.sents" % FLAGS.input_vocab_size)
    parse_vocab_path = os.path.join(FLAGS.data_dir,
                                    "vocab%d.parse" % FLAGS.output_vocab_size)
    sents_vocab, rev_sent_vocab = data_utils.initialize_vocabulary(sents_vocab_path)
    _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)

    gold_file_name = os.path.join(FLAGS.train_dir, 'dev.gold.txt')
    # file with matched brackets
    decoded_br_file_name = os.path.join(FLAGS.train_dir, 'dev.decoded.br.txt')
    # file filler XX help as well
    decoded_mx_file_name = os.path.join(FLAGS.train_dir, 'dev.decoded.mx.txt')

    print("Doing evalb")
    print("Debug - step: ", model.global_step.eval())

    num_sents = []
    num_valid_br = []   # evalb "valid sentence" counts, bracket-only files
    num_valid_mx = []   # same, for the matched-XX files
    br = []             # per-batch [matched, gold, test] bracket counts (bracket-only)
    mx = []             # per-batch [matched, gold, test] bracket counts (matched-XX)

    for bucket_id in xrange(len(_buckets)):
        bucket_size = len(dev_set[bucket_id])
        offsets = np.arange(0, bucket_size, eval_batch_size)
        for batch_offset in offsets:
            fout_gold = open(gold_file_name, 'w')
            fout_br = open(decoded_br_file_name, 'w')
            fout_mx = open(decoded_mx_file_name, 'w')
            try:
                all_examples = dev_set[bucket_id][batch_offset:batch_offset + eval_batch_size]
                model_dev.batch_size = len(all_examples)
                token_ids = [x[0] for x in all_examples]
                gold_ids = [x[1] for x in all_examples]
                dec_ids = [[]] * len(token_ids)
                encoder_inputs, decoder_inputs, target_weights, seq_len = \
                    model_dev.get_decode_batch({bucket_id: zip(token_ids, dec_ids)},
                                               bucket_id)
                _, _, output_logits = model_dev.step(sess, encoder_inputs,
                                                     decoder_inputs, target_weights,
                                                     seq_len, bucket_id, True)
                # Greedy decode: pick the argmax token at every time step,
                # then transpose so rows are sentences.
                outputs = [np.argmax(logit, axis=1) for logit in output_logits]
                to_decode = np.array(outputs).T
                num_valid = 0
                num_sents.append(to_decode.shape[0])
                for sent_id in range(to_decode.shape[0]):
                    parse = list(to_decode[sent_id, :])
                    # Truncate the hypothesis at the first EOS, if any.
                    if data_utils.EOS_ID in parse:
                        parse = parse[:parse.index(data_utils.EOS_ID)]
                    decoded_parse = []
                    for output in parse:
                        if output < len(rev_parse_vocab):
                            decoded_parse.append(tf.compat.as_str(rev_parse_vocab[output]))
                        else:
                            decoded_parse.append("_UNK")
                    # add brackets for tree balance
                    parse_br, valid = add_brackets(decoded_parse)
                    num_valid += valid
                    # get gold parse, gold sentence
                    gold_parse = [tf.compat.as_str(rev_parse_vocab[output])
                                  for output in gold_ids[sent_id]]
                    sent_text = [tf.compat.as_str(rev_sent_vocab[output])
                                 for output in token_ids[sent_id]]
                    # parse with also matching "XX" length
                    parse_mx = match_length(parse_br, sent_text)
                    to_write_gold = merge_sent_tree(gold_parse[:-1], sent_text)  # account for EOS
                    to_write_br = merge_sent_tree(parse_br, sent_text)
                    to_write_mx = merge_sent_tree(parse_mx, sent_text)
                    fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
                    fout_br.write('{}\n'.format(' '.join(to_write_br)))
                    fout_mx.write('{}\n'.format(' '.join(to_write_mx)))
            finally:
                # Close unconditionally so a failed batch does not leak handles.
                fout_gold.close()
                fout_br.close()
                fout_mx.close()

            # Evaluate the current batch with evalb, once per correction mode.
            for decoded_file, valid_counts, stats in (
                    (decoded_br_file_name, num_valid_br, br),
                    (decoded_mx_file_name, num_valid_mx, mx)):
                cmd = [evalb_path, '-p', prm_file, gold_file_name, decoded_file]
                p = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                     stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                out, err = p.communicate()
                out_lines = out.split("\n")
                vv = [x for x in out_lines if "Number of Valid sentence " in x]
                # Guard: previously vv[0] raised IndexError when evalb emitted
                # no valid-sentence line (e.g. it crashed on a malformed tree).
                if not vv:
                    print("Warning: evalb produced no valid-sentence count for %s"
                          % decoded_file)
                    continue
                valid_counts.append(float(vv[0].split()[-1]))
                matched, gold, test = process_eval(out_lines, to_decode.shape[0])
                stats.append([matched, gold, test])

    def _prf(stats):
        # stats rows: [matched, gold, test] bracket counts. Returns (p, r, f1)
        # with explicit zero guards instead of letting numpy emit inf/nan.
        if not stats:
            return 0.0, 0.0, 0.0
        matched, gold, test = np.array(stats, dtype=float).sum(axis=0)
        prec = matched / test if test else 0.0
        rec = matched / gold if gold else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
        return prec, rec, f1

    sum_br_pre, sum_br_rec, sum_br_f1 = _prf(br)
    sum_mx_pre, sum_mx_rec, sum_mx_f1 = _prf(mx)
    br_valid = sum(num_valid_br)
    mx_valid = sum(num_valid_mx)
    print("Bracket only -- Num valid sentences: %d; p: %.4f; r: %.4f; f1: %.4f"
          % (br_valid, sum_br_pre, sum_br_rec, sum_br_f1))
    print("Matched XX -- Num valid sentences: %d; p: %.4f; r: %.4f; f1: %.4f"
          % (mx_valid, sum_mx_pre, sum_mx_rec, sum_mx_f1))
    # NOTE(review): f1 computed above is a ratio in [0, 1], so the ">= 80"
    # threshold looks unreachable as written — confirm intended scale.
    if br_valid >= 5000 and sum_br_f1 >= 80:
        print("Very good model")
        checkpoint_path = os.path.join(FLAGS.train_dir, "good-model.ckpt")
        model.saver.save(sess, checkpoint_path, global_step=model.global_step)
def write_decode(model_dev, sess, dev_set, eval_batch_size, globstep, eval_now=False):
    """Decode the dev set with the text+speech model and write parse files.

    Writes three files under FLAGS.train_dir, suffixed with the current step:
    the gold parses, the bracket-corrected decodes, and the XX-length-matched
    decodes. When ``eval_now`` is True, also runs the external ``evalb`` tool
    on both corrected files and returns the matched-XX F1 (0.0 on evalb
    failure); otherwise returns -1.0.

    Args:
      model_dev: decoding model; ``batch_size`` is reset per batch.
      sess: active TensorFlow session.
      dev_set: dict bucket_id -> list of examples; each example is indexed as
        (token_ids, gold_ids, partition, speech_feats, pause_before, pause_after)
        — inferred from the x[0]..x[5] accesses below.
      eval_batch_size: number of examples decoded per model step.
      globstep: global step number, used only to name the output files.
      eval_now: if True, score with evalb and return the matched-XX F1.
    """
    # Load vocabularies.
    sents_vocab_path = os.path.join(FLAGS.data_dir, FLAGS.source_vocab_file)
    parse_vocab_path = os.path.join(FLAGS.data_dir, FLAGS.target_vocab_file)
    sents_vocab, rev_sent_vocab = data_utils.initialize_vocabulary(
        sents_vocab_path)
    _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)
    # current progress
    stepname = str(globstep)
    gold_file_name = os.path.join(FLAGS.train_dir,
                                  'gold-step' + stepname + '.txt')
    print(gold_file_name)
    # file with matched brackets
    decoded_br_file_name = os.path.join(FLAGS.train_dir,
                                        'decoded-br-step' + stepname + '.txt')
    # file filler XX help as well
    decoded_mx_file_name = os.path.join(FLAGS.train_dir,
                                        'decoded-mx-step' + stepname + '.txt')
    fout_gold = open(gold_file_name, 'w')
    fout_br = open(decoded_br_file_name, 'w')
    fout_mx = open(decoded_mx_file_name, 'w')
    num_dev_sents = 0
    for bucket_id in xrange(len(_buckets)):
        bucket_size = len(dev_set[bucket_id])
        offsets = np.arange(0, bucket_size, eval_batch_size)
        for batch_offset in offsets:
            all_examples = dev_set[bucket_id][batch_offset:batch_offset +
                                              eval_batch_size]
            model_dev.batch_size = len(all_examples)
            token_ids = [x[0] for x in all_examples]
            partition = [x[2] for x in all_examples]
            speech_feats = [x[3] for x in all_examples]
            pbfs = [x[4] for x in all_examples]   # pause-before features
            pafs = [x[5] for x in all_examples]   # pause-after features
            gold_ids = [x[1] for x in all_examples]
            # Empty decoder inputs: decoding is driven by the model itself.
            dec_ids = [[]] * len(token_ids)
            text_encoder_inputs, speech_encoder_inputs, pause_bef, pause_aft, \
                decoder_inputs, target_weights, \
                text_seq_len, speech_seq_len = model_dev.get_batch(\
                {bucket_id: zip(token_ids, dec_ids, partition, speech_feats, pbfs, pafs)}, \
                bucket_id, batch_offset, FLAGS.use_speech)
            _, _, output_logits = model_dev.step(sess, [text_encoder_inputs, speech_encoder_inputs, \
                pause_bef, pause_aft],\
                decoder_inputs, target_weights, text_seq_len, speech_seq_len, \
                bucket_id, True, FLAGS.use_speech)
            # Greedy decode: argmax per time step, transpose to sentence rows.
            outputs = [np.argmax(logit, axis=1) for logit in output_logits]
            to_decode = np.array(outputs).T
            num_dev_sents += to_decode.shape[0]
            num_valid = 0
            for sent_id in range(to_decode.shape[0]):
                parse = list(to_decode[sent_id, :])
                # Truncate the hypothesis at the first EOS token.
                if data_utils.EOS_ID in parse:
                    parse = parse[:parse.index(data_utils.EOS_ID)]
                decoded_parse = []
                for output in parse:
                    if output < len(rev_parse_vocab):
                        decoded_parse.append(
                            tf.compat.as_str(rev_parse_vocab[output]))
                    else:
                        decoded_parse.append("_UNK")
                # add brackets for tree balance
                parse_br, valid = add_brackets(decoded_parse)
                num_valid += valid
                # get gold parse, gold sentence
                gold_parse = [
                    tf.compat.as_str(rev_parse_vocab[output])
                    for output in gold_ids[sent_id]
                ]
                sent_text = [
                    tf.compat.as_str(rev_sent_vocab[output])
                    for output in token_ids[sent_id]
                ]
                # parse with also matching "XX" length
                parse_mx = match_length(parse_br, sent_text)
                parse_mx = delete_empty_constituents(parse_mx)
                to_write_gold = merge_sent_tree(gold_parse[:-1],
                                                sent_text)  # account for EOS
                to_write_br = merge_sent_tree(parse_br, sent_text)
                to_write_mx = merge_sent_tree(parse_mx, sent_text)
                fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
                fout_br.write('{}\n'.format(' '.join(to_write_br)))
                fout_mx.write('{}\n'.format(' '.join(to_write_mx)))
    # Write to file
    fout_gold.close()
    fout_br.close()
    fout_mx.close()
    # -1.0 is the "did not evaluate" sentinel returned when eval_now is False.
    f_score_mx = -1.0
    if eval_now:
        f_score_mx = 0.0
        correction_types = ["Bracket only", "Matched XX"]
        corrected_files = [decoded_br_file_name, decoded_mx_file_name]
        for c_type, c_file in zip(correction_types, corrected_files):
            cmd = [evalb_path, '-p', prm_file, gold_file_name, c_file]
            p = subprocess.Popen(cmd,
                                 stdin=subprocess.PIPE,
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
            out, err = p.communicate()
            out_lines = out.split("\n")
            vv = [x for x in out_lines if "Number of Valid sentence " in x]
            # evalb produced no valid-sentence line: bail out with 0.0
            # (NOTE: this also skips any remaining correction types).
            if len(vv) == 0:
                return 0.0
            s1 = float(vv[0].split()[-1])
            m_br, g_br, t_br = process_eval(out_lines, num_dev_sents)
            try:
                recall = float(m_br) / float(g_br)
                prec = float(m_br) / float(t_br)
                f_score = 2 * recall * prec / (recall + prec)
            except ZeroDivisionError as e:
                # Zero gold/test bracket counts: report all-zero scores.
                recall, prec, f_score = 0.0, 0.0, 0.0
            print("%s -- Num valid sentences: %d; p: %.4f; r: %.4f; f1: %.4f" \
                %(c_type, s1, prec, recall, f_score) )
            sys.stdout.flush()
            if "XX" in c_type:
                f_score_mx = f_score
    return f_score_mx
def write_decode(model_dev, sess, dev_set, eval_batch_size, globstep):
    """Decode the dev set with the text+MFCC model and write parse files.

    Writes three step-suffixed files under FLAGS.train_dir: gold parses,
    bracket-corrected decodes, and XX-length-matched decodes. Returns None.

    NOTE(review): despite the "Load vocabularies" comment below, this variant
    never initializes ``rev_sent_vocab``/``rev_parse_vocab`` — they must exist
    as module-level globals or this raises NameError. Verify against the
    module this was taken from.

    Args:
      model_dev: decoding model; ``batch_size`` is reset per batch.
      sess: active TensorFlow session.
      dev_set: dict bucket_id -> list of examples indexed as
        (token_ids, gold_ids, mfccs) — inferred from the x[0]..x[2] accesses.
      eval_batch_size: number of examples decoded per model step.
      globstep: global step number, used only to name the output files.
    """
    # Load vocabularies.
    stepname = str(globstep)
    gold_file_name = os.path.join(FLAGS.train_dir,
                                  'gold-step' + stepname + '.txt')
    print(gold_file_name)
    # file with matched brackets
    decoded_br_file_name = os.path.join(FLAGS.train_dir,
                                        'decoded-br-step' + stepname + '.txt')
    # file filler XX help as well
    decoded_mx_file_name = os.path.join(FLAGS.train_dir,
                                        'decoded-mx-step' + stepname + '.txt')
    fout_gold = open(gold_file_name, 'w')
    fout_br = open(decoded_br_file_name, 'w')
    fout_mx = open(decoded_mx_file_name, 'w')
    for bucket_id in xrange(len(_buckets)):
        bucket_size = len(dev_set[bucket_id])
        offsets = np.arange(0, bucket_size, eval_batch_size)
        for batch_offset in offsets:
            all_examples = dev_set[bucket_id][batch_offset:batch_offset +
                                              eval_batch_size]
            model_dev.batch_size = len(all_examples)
            token_ids = [x[0] for x in all_examples]
            mfccs = [x[2] for x in all_examples]
            gold_ids = [x[1] for x in all_examples]
            # Empty decoder inputs: decoding is driven by the model itself.
            dec_ids = [[]] * len(token_ids)
            text_encoder_inputs, speech_encoder_inputs, decoder_inputs, target_weights, seq_len = model_dev.get_batch(
                {bucket_id: zip(token_ids, dec_ids, mfccs)}, bucket_id)
            _, _, output_logits = model_dev.step(
                sess, [text_encoder_inputs, speech_encoder_inputs],
                decoder_inputs, target_weights, seq_len, bucket_id, True)
            # Greedy decode: argmax per time step, transpose to sentence rows.
            outputs = [np.argmax(logit, axis=1) for logit in output_logits]
            to_decode = np.array(outputs).T
            num_valid = 0
            for sent_id in range(to_decode.shape[0]):
                parse = list(to_decode[sent_id, :])
                # Truncate the hypothesis at the first EOS token.
                if data_utils.EOS_ID in parse:
                    parse = parse[:parse.index(data_utils.EOS_ID)]
                # raw decoded parse
                # print(parse)
                decoded_parse = []
                for output in parse:
                    if output < len(rev_parse_vocab):
                        decoded_parse.append(
                            tf.compat.as_str(rev_parse_vocab[output]))
                    else:
                        decoded_parse.append("_UNK")
                # decoded_parse = [tf.compat.as_str(rev_parse_vocab[output]) for output in parse]
                # add brackets for tree balance
                parse_br, valid = add_brackets(decoded_parse)
                num_valid += valid
                # get gold parse, gold sentence
                gold_parse = [
                    tf.compat.as_str(rev_parse_vocab[output])
                    for output in gold_ids[sent_id]
                ]
                sent_text = [
                    tf.compat.as_str(rev_sent_vocab[output])
                    for output in token_ids[sent_id]
                ]
                # parse with also matching "XX" length
                parse_mx = match_length(parse_br, sent_text)
                parse_mx = delete_empty_constituents(parse_mx)
                # NOTE(review): sibling variants pass gold_parse[:-1] here to
                # drop EOS; this one passes the full list despite the comment.
                # Confirm whether gold_ids in this pipeline lack the EOS token.
                to_write_gold = merge_sent_tree(gold_parse,
                                                sent_text)  # account for EOS
                to_write_br = merge_sent_tree(parse_br, sent_text)
                to_write_mx = merge_sent_tree(parse_mx, sent_text)
                fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
                fout_br.write('{}\n'.format(' '.join(to_write_br)))
                fout_mx.write('{}\n'.format(' '.join(to_write_mx)))
    # Write to file
    fout_gold.close()
    fout_br.close()
    fout_mx.close()
def write_decode(model_dev, sess, dev_set, get_results=False):
    """Perform evaluation: decode the dev set, write parse files, score with evalb.

    Writes gold/bracket-corrected/XX-matched parse files plus a sentence-id
    file under FLAGS.train_dir, then runs the external ``evalb`` binary on
    both corrected files and returns the matched-XX F1 score (0.0 when evalb
    produces no usable output).

    Args:
      model_dev: decoding model; ``batch_size`` is reset per batch.
      sess: active TensorFlow session.
      dev_set: dict bucket_id -> list of examples indexed as
        (sent_id, token_ids, gold_ids); token_ids is a dict with a 'word' key
        — inferred from the accesses below.
      get_results: currently unused; kept for interface compatibility.

    Returns:
      float: matched-XX F1, or 0.0 if evalb output could not be parsed.
    """
    # Load vocabularies.
    sents_vocab_path = os.path.join(FLAGS.vocab_dir,
                                    FLAGS.source_vocab_file['word'])
    parse_vocab_path = os.path.join(FLAGS.vocab_dir, FLAGS.target_vocab_file)
    sents_vocab, rev_sent_vocab = data_utils.initialize_vocabulary(
        sents_vocab_path)
    _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)

    gold_file_name = os.path.join(FLAGS.train_dir, 'gold.txt')
    # file with matched brackets
    decoded_br_file_name = os.path.join(FLAGS.train_dir, 'decoded.br.txt')
    # file filler XX help as well
    decoded_mx_file_name = os.path.join(FLAGS.train_dir, 'decoded.mx.txt')
    sent_id_file_name = os.path.join(FLAGS.train_dir, 'sent_id.txt')
    fout_gold = open(gold_file_name, 'w')
    fout_br = open(decoded_br_file_name, 'w')
    fout_mx = open(decoded_mx_file_name, 'w')
    fsent_id = open(sent_id_file_name, 'w')

    num_dev_sents = 0
    for bucket_id in xrange(len(_buckets)):
        bucket_size = len(dev_set[bucket_id])
        offsets = np.arange(0, bucket_size, FLAGS.batch_size)
        for batch_offset in offsets:
            all_examples = dev_set[bucket_id][batch_offset:batch_offset +
                                              FLAGS.batch_size]
            model_dev.batch_size = len(all_examples)
            sent_ids = [x[0] for x in all_examples]
            for sent_id in sent_ids:
                fsent_id.write(sent_id + "\n")
            token_ids = [x[1] for x in all_examples]
            gold_ids = [x[2] for x in all_examples]
            # Empty decoder inputs: decoding is driven by the model itself.
            dec_ids = [[]] * len(token_ids)
            encoder_inputs, seq_len, decoder_inputs, seq_len_target =\
                model_dev.get_batch({bucket_id: zip(sent_ids, token_ids, dec_ids)})
            output_logits = model_dev.step(sess, encoder_inputs, seq_len,
                                           decoder_inputs, seq_len_target)
            # Flat logits: argmax over vocab, then reshape to T*B and
            # transpose so each row is one sentence.
            outputs = np.argmax(output_logits, axis=1)
            outputs = np.reshape(
                outputs, (max(seq_len_target), model_dev.batch_size))  # T*B
            to_decode = np.array(outputs).T
            num_dev_sents += to_decode.shape[0]
            for sent_id in range(to_decode.shape[0]):
                parse = list(to_decode[sent_id, :])
                # Truncate the hypothesis at the first EOS token.
                if data_utils.EOS_ID in parse:
                    parse = parse[:parse.index(data_utils.EOS_ID)]
                decoded_parse = []
                for output in parse:
                    if output < len(rev_parse_vocab):
                        decoded_parse.append(
                            tf.compat.as_str(rev_parse_vocab[output]))
                    else:
                        decoded_parse.append("_UNK")
                # add brackets for tree balance
                parse_br, valid = add_brackets(decoded_parse)
                # get gold parse, gold sentence
                gold_parse = [
                    tf.compat.as_str(rev_parse_vocab[output])
                    for output in gold_ids[sent_id]
                ]
                sent_text = [
                    tf.compat.as_str(rev_sent_vocab[output])
                    for output in token_ids[sent_id]['word']
                ]
                # parse with also matching "XX" length
                parse_mx = match_length(parse_br, sent_text)
                parse_mx = delete_empty_constituents(parse_mx)
                # account for EOS
                to_write_gold = merge_sent_tree(gold_parse[:-1], sent_text)
                to_write_br = merge_sent_tree(parse_br, sent_text)
                to_write_mx = merge_sent_tree(parse_mx, sent_text)
                fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
                fout_br.write('{}\n'.format(' '.join(to_write_br)))
                fout_mx.write('{}\n'.format(' '.join(to_write_mx)))
    # Write to file
    fout_gold.close()
    fout_br.close()
    fout_mx.close()
    fsent_id.close()

    f_score_mx = 0.0
    correction_types = ["Bracket only", "Matched XX"]
    corrected_files = [decoded_br_file_name, decoded_mx_file_name]
    for c_type, c_file in zip(correction_types, corrected_files):
        cmd = [evalb_path, '-p', prm_file, gold_file_name, c_file]
        p = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        out, err = p.communicate()
        # NOTE: mode "w" means only the last correction type's evalb output
        # survives in log.txt — previous iterations are overwritten.
        with open(os.path.join(FLAGS.train_dir, "log.txt"), "w") as log_f:
            log_f.write(out)
        out_lines = out.split("\n")
        vv = [x for x in out_lines if "Number of Valid sentence " in x]
        # evalb produced no valid-sentence line: bail out with 0.0.
        if len(vv) == 0:
            return 0.0
        s1 = float(vv[0].split()[-1])
        m_br, g_br, t_br = process_eval(out_lines, num_dev_sents)
        try:
            recall = float(m_br) / float(g_br)
            prec = float(m_br) / float(t_br)
            f_score = 2 * recall * prec / (recall + prec)
        except ZeroDivisionError:
            # Zero gold/test bracket counts: report all-zero scores.
            recall, prec, f_score = 0.0, 0.0, 0.0
        print("%s -- Num valid sentences: %d; p: %.4f; r: %.4f; f1: %.4f" %
              (c_type, s1, prec, recall, f_score))
        if "XX" in c_type:
            f_score_mx = f_score
    return f_score_mx
def write_decode(model_dev, sess, dev_set,
                 input_vocab_size=90000, output_vocab_size=128):
    """Debug decode: write gold, corrected, raw, and sentence files.

    Decodes the dev set and writes five files: gold parses, bracket-corrected
    decodes, and XX-matched decodes under ``train_dir``, plus raw decoded
    parses ('raw.txt') and plain sentences ('sent.txt') in the current
    working directory. Returns None.

    Args:
      model_dev: decoding model; ``batch_size`` is reset per batch.
      sess: active TensorFlow session.
      dev_set: dict bucket_id -> list of (token_ids, gold_ids) pairs.
      input_vocab_size: size used to locate the sentence vocab file
        (previously hard-coded to 90000).
      output_vocab_size: size used to locate the parse vocab file
        (previously hard-coded to 128).
    """
    # Load vocabularies (sizes are now parameters; defaults keep old behavior).
    sents_vocab_path = os.path.join(data_dir, "vocab%d.sents" % input_vocab_size)
    parse_vocab_path = os.path.join(data_dir, "vocab%d.parse" % output_vocab_size)
    sents_vocab, rev_sent_vocab = data_utils.initialize_vocabulary(sents_vocab_path)
    _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)

    gold_file_name = os.path.join(train_dir, 'debug.gold.txt')
    # file with matched brackets
    decoded_br_file_name = os.path.join(train_dir, 'debug.decoded.br.txt')
    # file filler XX help as well
    decoded_mx_file_name = os.path.join(train_dir, 'debug.decoded.mx.txt')

    fout_gold = open(gold_file_name, 'w')
    fout_br = open(decoded_br_file_name, 'w')
    fout_mx = open(decoded_mx_file_name, 'w')
    # NOTE: these two land in the current working directory, not train_dir.
    fout_raw = open('raw.txt', 'w')
    fout_sent = open('sent.txt', 'w')
    try:
        for bucket_id in xrange(len(_buckets)):
            bucket_size = len(dev_set[bucket_id])
            offsets = np.arange(0, bucket_size, batch_size)
            for batch_offset in offsets:
                all_examples = dev_set[bucket_id][batch_offset:batch_offset + batch_size]
                model_dev.batch_size = len(all_examples)
                token_ids = [x[0] for x in all_examples]
                gold_ids = [x[1] for x in all_examples]
                # Empty decoder inputs: decoding is driven by the model itself.
                dec_ids = [[]] * len(token_ids)
                encoder_inputs, decoder_inputs, target_weights = model_dev.get_decode_batch(
                    {bucket_id: zip(token_ids, dec_ids)}, bucket_id)
                _, _, output_logits = model_dev.step(sess, encoder_inputs, decoder_inputs,
                                                     target_weights, bucket_id, True)
                # Greedy decode: argmax per time step, transpose to rows.
                outputs = [np.argmax(logit, axis=1) for logit in output_logits]
                to_decode = np.array(outputs).T
                num_valid = 0
                for sent_id in range(to_decode.shape[0]):
                    parse = list(to_decode[sent_id, :])
                    # Truncate the hypothesis at the first EOS token.
                    if data_utils.EOS_ID in parse:
                        parse = parse[:parse.index(data_utils.EOS_ID)]
                    decoded_parse = []
                    for output in parse:
                        if output < len(rev_parse_vocab):
                            decoded_parse.append(tf.compat.as_str(rev_parse_vocab[output]))
                        else:
                            decoded_parse.append("_UNK")
                    # add brackets for tree balance
                    parse_br, valid = add_brackets(decoded_parse)
                    num_valid += valid
                    # get gold parse, gold sentence
                    gold_parse = [tf.compat.as_str(rev_parse_vocab[output])
                                  for output in gold_ids[sent_id]]
                    sent_text = [tf.compat.as_str(rev_sent_vocab[output])
                                 for output in token_ids[sent_id]]
                    # parse with also matching "XX" length
                    parse_mx = match_length(parse_br, sent_text)
                    parse_mx = delete_empty_constituents(parse_mx)
                    to_write_gold = merge_sent_tree(gold_parse[:-1], sent_text)  # account for EOS
                    to_write_br = merge_sent_tree(parse_br, sent_text)
                    to_write_mx = merge_sent_tree(parse_mx, sent_text)
                    fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
                    fout_br.write('{}\n'.format(' '.join(to_write_br)))
                    fout_mx.write('{}\n'.format(' '.join(to_write_mx)))
                    fout_sent.write('{}\n'.format(' '.join(sent_text)))
                    fout_raw.write('{}\n'.format(' '.join(decoded_parse)))
    finally:
        # Close unconditionally so a failed decode does not leak handles.
        fout_gold.close()
        fout_br.close()
        fout_mx.close()
        fout_sent.close()
        fout_raw.close()
decoded_parse.append("_UNK") # decoded_parse = [tf.compat.as_str(rev_parse_vocab[output]) for output in parse] # add brackets for tree balance parse_br, valid = add_brackets(decoded_parse) num_valid += valid # get gold parse, gold sentence gold_parse = [ tf.compat.as_str(rev_parse_vocab[output]) for output in gold_ids[sent_id] ] sent_text = [ tf.compat.as_str(rev_sent_vocab[output]) for output in token_ids[sent_id] ] # parse with also matching "XX" length parse_mx = match_length(parse_br, sent_text) parse_mx = delete_empty_constituents(parse_mx) to_write_gold = merge_sent_tree(gold_parse, sent_text) # account for EOS to_write_br = merge_sent_tree(parse_br, sent_text) to_write_mx = merge_sent_tree(parse_mx, sent_text) def do_evalb(model_dev, sess, dev_set, eval_batch_size): gold_file_name = os.path.join(train_dir, 'partial.gold.txt') # file with matched brackets decoded_br_file_name = os.path.join(train_dir, 'partial.decoded.br.txt') # file filler XX help as well decoded_mx_file_name = os.path.join(train_dir, 'partial.decoded.mx.txt')
sents_vocab_path) _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path) _buckets = [(10, 40), (25, 85), (40, 150)] gold_file = 'gold_dev.txt' baseline_file = 'baseline_dev.txt' fg = open(gold_file, 'w') fb = open(baseline_file, 'w') for bucket_id in xrange(len(_buckets)): for sentence in dev_set[bucket_id]: slength = len(sentence[0]) toks = sentence[0] gold = sentence[1] gold_parse = [ tf.compat.as_str(rev_parse_vocab[output]) for output in gold ] sent_text = [ tf.compat.as_str(rev_sent_vocab[output]) for output in toks ] prediction = random.choice(baseline_dict[slength]) parse_mx = match_length(prediction.split(), sent_text) to_write_mx = merge_sent_tree(parse_mx, sent_text) to_write_gold = merge_sent_tree(gold_parse, sent_text) fg.write('{}\n'.format(' '.join(to_write_gold))) fb.write('{}\n'.format(' '.join(to_write_mx))) fb.close() fg.close()