Example #1
def write_decode(model_dev,
                 sess,
                 dev_set,
                 eval_batch_size,
                 globstep,
                 eval_now=False):
    # Load vocabularies.
    sents_vocab_path = os.path.join(FLAGS.data_dir, FLAGS.source_vocab_file)
    parse_vocab_path = os.path.join(FLAGS.data_dir, FLAGS.target_vocab_file)
    sents_vocab, rev_sent_vocab = data_utils.initialize_vocabulary(
        sents_vocab_path)
    _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)

    # Tag the output files with the current global step.
    stepname = str(globstep)
    gold_file_name = os.path.join(FLAGS.train_dir,
                                  'gold-step' + stepname + '.txt')
    print(gold_file_name)
    # file with matched brackets
    decoded_br_file_name = os.path.join(FLAGS.train_dir,
                                        'decoded-br-step' + stepname + '.txt')
    # file with filler 'XX' tokens length-matched to the sentence as well
    decoded_mx_file_name = os.path.join(FLAGS.train_dir,
                                        'decoded-mx-step' + stepname + '.txt')

    fout_gold = open(gold_file_name, 'w')
    fout_br = open(decoded_br_file_name, 'w')
    fout_mx = open(decoded_mx_file_name, 'w')

    num_dev_sents = 0
    for bucket_id in range(len(_buckets)):
        bucket_size = len(dev_set[bucket_id])
        offsets = np.arange(0, bucket_size, eval_batch_size)
        for batch_offset in offsets:
            all_examples = dev_set[bucket_id][batch_offset:batch_offset +
                                              eval_batch_size]
            model_dev.batch_size = len(all_examples)
            token_ids = [x[0] for x in all_examples]
            partition = [x[2] for x in all_examples]
            speech_feats = [x[3] for x in all_examples]
            gold_ids = [x[1] for x in all_examples]
            dec_ids = [[] for _ in token_ids]  # independent empty lists
            (text_encoder_inputs, speech_encoder_inputs, decoder_inputs,
             target_weights, text_seq_len,
             speech_seq_len) = model_dev.get_batch_zero(
                 {bucket_id: list(zip(token_ids, dec_ids, partition,
                                      speech_feats))},
                 bucket_id, batch_offset)
            _, _, output_logits = model_dev.step(
                sess, [text_encoder_inputs, speech_encoder_inputs],
                decoder_inputs, target_weights, text_seq_len, speech_seq_len,
                bucket_id, True)
            outputs = [np.argmax(logit, axis=1) for logit in output_logits]
            to_decode = np.array(outputs).T
            num_dev_sents += to_decode.shape[0]
            num_valid = 0
            for sent_id in range(to_decode.shape[0]):
                parse = list(to_decode[sent_id, :])
                if data_utils.EOS_ID in parse:
                    parse = parse[:parse.index(data_utils.EOS_ID)]
                decoded_parse = []
                for output in parse:
                    if output < len(rev_parse_vocab):
                        decoded_parse.append(
                            tf.compat.as_str(rev_parse_vocab[output]))
                    else:
                        decoded_parse.append("_UNK")
                # balance the brackets so the output parses as a tree
                parse_br, valid = add_brackets(decoded_parse)
                num_valid += valid
                # get gold parse, gold sentence
                gold_parse = [
                    tf.compat.as_str(rev_parse_vocab[output])
                    for output in gold_ids[sent_id]
                ]
                sent_text = [
                    tf.compat.as_str(rev_sent_vocab[output])
                    for output in token_ids[sent_id]
                ]
                # parse with 'XX' fillers also length-matched to the sentence
                parse_mx = match_length(parse_br, sent_text)
                parse_mx = delete_empty_constituents(parse_mx)

                to_write_gold = merge_sent_tree(gold_parse[:-1],
                                                sent_text)  # drop trailing EOS
                to_write_br = merge_sent_tree(parse_br, sent_text)
                to_write_mx = merge_sent_tree(parse_mx, sent_text)

                fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
                fout_br.write('{}\n'.format(' '.join(to_write_br)))
                fout_mx.write('{}\n'.format(' '.join(to_write_mx)))

    # Close the output files, flushing them to disk.
    fout_gold.close()
    fout_br.close()
    fout_mx.close()

    if eval_now:
        correction_types = ["Bracket only", "Matched XX"]
        corrected_files = [decoded_br_file_name, decoded_mx_file_name]

        for c_type, c_file in zip(correction_types, corrected_files):
            cmd = [evalb_path, '-p', prm_file, gold_file_name, c_file]
            p = subprocess.Popen(cmd,
                                 stdin=subprocess.PIPE,
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
            out, err = p.communicate()
            out_lines = out.decode('utf-8').split("\n")
            vv = [x for x in out_lines if "Number of Valid sentence " in x]
            s1 = float(vv[0].split()[-1])
            m_br, g_br, t_br = process_eval(out_lines, num_dev_sents)

            recall = float(m_br) / float(g_br)
            prec = float(m_br) / float(t_br)
            f_score = 2 * recall * prec / (recall + prec)

            print("%s -- Num valid sentences: %d; p: %.4f; r: %.4f; f1: %.4f" %
                  (c_type, s1, prec, recall, f_score))
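The helpers add_brackets, match_length, delete_empty_constituents, merge_sent_tree and process_eval are defined elsewhere in the repository and are not shown here. As a rough guide to what process_eval has to extract from EVALB's output, here is a minimal sketch; the helper name and EVALB's per-sentence column order (ID, length, status, recall, precision, matched, gold, test, ...) are assumptions to check against your EVALB build:

def process_eval_sketch(out_lines, num_sents):
    # Sum matched/gold/test bracket counts over EVALB's per-sentence
    # rows, which begin with an integer sentence ID.
    matched = gold = test = 0
    seen = 0
    for line in out_lines:
        fields = line.split()
        if seen < num_sents and len(fields) >= 8 and fields[0].isdigit():
            matched += int(fields[5])  # matched brackets
            gold += int(fields[6])     # gold brackets
            test += int(fields[7])     # test brackets
            seen += 1
    return matched, gold, test

The three counts feed the recall (matched/gold) and precision (matched/test) computations above.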
Example #2
def do_evalb(model_dev, sess, dev_set, eval_batch_size):
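    # rev_parse_vocab and rev_sent_vocab are module-level globals here,
    # unlike Example #1, which loads the vocabularies inside the function.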
    gold_file_name = os.path.join(FLAGS.train_dir, 'partial.gold.txt')
    # file with matched brackets
    decoded_br_file_name = os.path.join(FLAGS.train_dir,
                                        'partial.decoded.br.txt')
    # file with filler 'XX' tokens length-matched to the sentence as well
    decoded_mx_file_name = os.path.join(FLAGS.train_dir,
                                        'partial.decoded.mx.txt')

    num_sents = []      # sentences per batch
    num_valid_br = []   # EVALB valid-sentence counts, bracket-only
    num_valid_mx = []   # EVALB valid-sentence counts, matched-XX
    br = []             # per-batch [matched, gold, test] bracket counts
    mx = []             # same, for the matched-XX files

    for bucket_id in range(len(_buckets)):
        bucket_size = len(dev_set[bucket_id])
        offsets = np.arange(0, bucket_size, eval_batch_size)
        for batch_offset in offsets:
            # the partial files are rewritten from scratch for every batch
            fout_gold = open(gold_file_name, 'w')
            fout_br = open(decoded_br_file_name, 'w')
            fout_mx = open(decoded_mx_file_name, 'w')
            all_examples = dev_set[bucket_id][batch_offset:batch_offset +
                                              eval_batch_size]
            model_dev.batch_size = len(all_examples)
            token_ids = [x[0] for x in all_examples]
            gold_ids = [x[1] for x in all_examples]
            mfccs = [x[2] for x in all_examples]
            dec_ids = [[] for _ in token_ids]  # independent empty lists
            (text_encoder_inputs, speech_encoder_inputs, decoder_inputs,
             target_weights, seq_len) = model_dev.get_batch(
                 {bucket_id: list(zip(token_ids, dec_ids, mfccs))}, bucket_id)
            _, _, output_logits = model_dev.step(
                sess, [text_encoder_inputs, speech_encoder_inputs],
                decoder_inputs, target_weights, seq_len, bucket_id, True)
            outputs = [np.argmax(logit, axis=1) for logit in output_logits]
            to_decode = np.array(outputs).T
            num_valid = 0
            num_sents.append(to_decode.shape[0])
            for sent_id in range(to_decode.shape[0]):
                parse = list(to_decode[sent_id, :])
                if data_utils.EOS_ID in parse:
                    parse = parse[:parse.index(data_utils.EOS_ID)]
                decoded_parse = []
                for output in parse:
                    if output < len(rev_parse_vocab):
                        decoded_parse.append(
                            tf.compat.as_str(rev_parse_vocab[output]))
                    else:
                        decoded_parse.append("_UNK")
                # balance the brackets so the output parses as a tree
                parse_br, valid = add_brackets(decoded_parse)
                num_valid += valid
                # get gold parse, gold sentence
                gold_parse = [
                    tf.compat.as_str(rev_parse_vocab[output])
                    for output in gold_ids[sent_id]
                ]
                sent_text = [
                    tf.compat.as_str(rev_sent_vocab[output])
                    for output in token_ids[sent_id]
                ]
                # parse with 'XX' fillers also length-matched to the sentence
                parse_mx = match_length(parse_br, sent_text)
                parse_mx = delete_empty_constituents(parse_mx)
                to_write_gold = merge_sent_tree(gold_parse, sent_text)
                to_write_br = merge_sent_tree(parse_br, sent_text)
                to_write_mx = merge_sent_tree(parse_mx, sent_text)
                fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
                fout_br.write('{}\n'.format(' '.join(to_write_br)))
                fout_mx.write('{}\n'.format(' '.join(to_write_mx)))

            # close the partial files so EVALB sees complete output
            fout_gold.close()
            fout_br.close()
            fout_mx.close()

            # evaluate current batch
            cmd = [
                evalb_path, '-p', prm_file, gold_file_name,
                decoded_br_file_name
            ]
            p = subprocess.Popen(cmd,
                                 stdin=subprocess.PIPE,
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
            out, err = p.communicate()
            out_lines = out.decode('utf-8').split("\n")
            vv = [x for x in out_lines if "Number of Valid sentence " in x]
            s1 = float(vv[0].split()[-1])
            num_valid_br.append(s1)
            m_br, g_br, t_br = process_eval(out_lines, to_decode.shape[0])
            br.append([m_br, g_br, t_br])

            cmd = [
                evalb_path, '-p', prm_file, gold_file_name,
                decoded_mx_file_name
            ]
            p = subprocess.Popen(cmd,
                                 stdin=subprocess.PIPE,
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
            out, err = p.communicate()
            out_lines = out.decode('utf-8').split("\n")
            vv = [x for x in out_lines if "Number of Valid sentence " in x]
            s2 = float(vv[0].split()[-1])
            num_valid_mx.append(s2)
            m_mx, g_mx, t_mx = process_eval(out_lines, to_decode.shape[0])
            mx.append([m_mx, g_mx, t_mx])

    # Micro-average the bracketing scores over all batches.
    br_all = np.array(br)
    mx_all = np.array(mx)
    sum_br_pre = sum(br_all[:, 0]) / sum(br_all[:, 2])
    sum_br_rec = sum(br_all[:, 0]) / sum(br_all[:, 1])
    sum_br_f1 = 2 * sum_br_pre * sum_br_rec / (sum_br_rec + sum_br_pre)
    sum_mx_pre = sum(mx_all[:, 0]) / sum(mx_all[:, 2])
    sum_mx_rec = sum(mx_all[:, 0]) / sum(mx_all[:, 1])
    sum_mx_f1 = 2 * sum_mx_pre * sum_mx_rec / (sum_mx_rec + sum_mx_pre)
    br_valid = sum(num_valid_br)
    mx_valid = sum(num_valid_mx)
    print(
        "Bracket only -- Num valid sentences: %d; p: %.4f; r: %.4f; f1: %.4f" %
        (br_valid, sum_br_pre, sum_br_rec, sum_br_f1))
    print(
        "Matched XX   -- Num valid sentences: %d; p: %.4f; r: %.4f; f1: %.4f" %
        (mx_valid, sum_mx_pre, sum_mx_rec, sum_mx_f1))
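Note that the final block micro-averages: it sums the per-batch [matched, gold, test] triples before dividing, rather than averaging per-batch F1 scores, which matches the convention EVALB itself uses. A compact equivalent of that aggregation (illustrative only, not part of the original code):

import numpy as np

def micro_prf(counts):
    # counts: iterable of per-batch [matched, gold, test] bracket counts
    matched, gold, test = np.asarray(counts, dtype=float).sum(axis=0)
    precision, recall = matched / test, matched / gold
    return precision, recall, 2 * precision * recall / (precision + recall)

micro_prf(br) reproduces sum_br_pre, sum_br_rec and sum_br_f1 above, and micro_prf(mx) the matched-XX scores.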
Example #3
def write_decode(model_dev, sess, dev_set, eval_batch_size, globstep):
    # rev_sent_vocab and rev_parse_vocab are module-level globals here.
    # Tag the output files with the current global step.
    stepname = str(globstep)
    gold_file_name = os.path.join(FLAGS.train_dir,
                                  'gold-step' + stepname + '.txt')
    print(gold_file_name)
    # file with matched brackets
    decoded_br_file_name = os.path.join(FLAGS.train_dir,
                                        'decoded-br-step' + stepname + '.txt')
    # file with filler 'XX' tokens length-matched to the sentence as well
    decoded_mx_file_name = os.path.join(FLAGS.train_dir,
                                        'decoded-mx-step' + stepname + '.txt')

    fout_gold = open(gold_file_name, 'w')
    fout_br = open(decoded_br_file_name, 'w')
    fout_mx = open(decoded_mx_file_name, 'w')

    for bucket_id in range(len(_buckets)):
        bucket_size = len(dev_set[bucket_id])
        offsets = np.arange(0, bucket_size, eval_batch_size)
        for batch_offset in offsets:
            all_examples = dev_set[bucket_id][batch_offset:batch_offset +
                                              eval_batch_size]
            model_dev.batch_size = len(all_examples)
            token_ids = [x[0] for x in all_examples]
            mfccs = [x[2] for x in all_examples]
            gold_ids = [x[1] for x in all_examples]
            dec_ids = [[] for _ in token_ids]  # independent empty lists
            (text_encoder_inputs, speech_encoder_inputs, decoder_inputs,
             target_weights, seq_len) = model_dev.get_batch(
                 {bucket_id: list(zip(token_ids, dec_ids, mfccs))}, bucket_id)
            _, _, output_logits = model_dev.step(
                sess, [text_encoder_inputs, speech_encoder_inputs],
                decoder_inputs, target_weights, seq_len, bucket_id, True)
            outputs = [np.argmax(logit, axis=1) for logit in output_logits]
            to_decode = np.array(outputs).T
            num_valid = 0
            for sent_id in range(to_decode.shape[0]):
                parse = list(to_decode[sent_id, :])
                if data_utils.EOS_ID in parse:
                    parse = parse[:parse.index(data_utils.EOS_ID)]
                decoded_parse = []
                for output in parse:
                    if output < len(rev_parse_vocab):
                        decoded_parse.append(
                            tf.compat.as_str(rev_parse_vocab[output]))
                    else:
                        decoded_parse.append("_UNK")
                # balance the brackets so the output parses as a tree
                parse_br, valid = add_brackets(decoded_parse)
                num_valid += valid
                # get gold parse, gold sentence
                gold_parse = [
                    tf.compat.as_str(rev_parse_vocab[output])
                    for output in gold_ids[sent_id]
                ]
                sent_text = [
                    tf.compat.as_str(rev_sent_vocab[output])
                    for output in token_ids[sent_id]
                ]
                # parse with 'XX' fillers also length-matched to the sentence
                parse_mx = match_length(parse_br, sent_text)
                parse_mx = delete_empty_constituents(parse_mx)

                to_write_gold = merge_sent_tree(gold_parse, sent_text)
                to_write_br = merge_sent_tree(parse_br, sent_text)
                to_write_mx = merge_sent_tree(parse_mx, sent_text)

                fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
                fout_br.write('{}\n'.format(' '.join(to_write_br)))
                fout_mx.write('{}\n'.format(' '.join(to_write_mx)))

    # Close the output files, flushing them to disk.
    fout_gold.close()
    fout_br.close()
    fout_mx.close()
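This variant only writes the decoded files, leaving scoring to a separate EVALB run. Conceptually, add_brackets repairs the decoder output so it parses as a tree: unmatched closing brackets are dropped, missing ones are appended, and the returned flag records whether the input was already valid. A hypothetical sketch under that assumption (the repository's actual implementation may differ):

def add_brackets_sketch(tokens):
    # Drop closing brackets with no open partner, then append the
    # closing brackets that are still missing at the end.
    depth = 0
    fixed = []
    for tok in tokens:
        if tok == ')' and depth == 0:
            continue  # stray closing bracket
        fixed.append(tok)
        if tok.startswith('('):
            depth += 1
        elif tok == ')':
            depth -= 1
    valid = int(depth == 0 and len(fixed) == len(tokens))
    fixed.extend([')'] * depth)
    return fixed, valid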
Example #4
def write_decode(model_dev, sess, dev_set):
    # Load vocabularies.
    sents_vocab_path = os.path.join(data_dir, "vocab%d.sents" % 90000)
    parse_vocab_path = os.path.join(data_dir, "vocab%d.parse" % 128)
    sents_vocab, rev_sent_vocab = data_utils.initialize_vocabulary(
        sents_vocab_path)
    _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)

    gold_file_name = os.path.join(train_dir, 'debug.gold.txt')
    # file with matched brackets
    decoded_br_file_name = os.path.join(train_dir, 'debug.decoded.br.txt')
    # file with filler 'XX' tokens length-matched to the sentence as well
    decoded_mx_file_name = os.path.join(train_dir, 'debug.decoded.mx.txt')

    fout_gold = open(gold_file_name, 'w')
    fout_br = open(decoded_br_file_name, 'w')
    fout_mx = open(decoded_mx_file_name, 'w')

    for bucket_id in range(len(_buckets)):
        bucket_size = len(dev_set[bucket_id])
        offsets = np.arange(0, bucket_size, batch_size)
        for batch_offset in offsets:
            all_examples = dev_set[bucket_id][batch_offset:batch_offset +
                                              batch_size]
            model_dev.batch_size = len(all_examples)
            token_ids = [x[0] for x in all_examples]
            gold_ids = [x[1] for x in all_examples]
            dec_ids = [[] for _ in token_ids]  # independent empty lists
            (encoder_inputs, decoder_inputs,
             target_weights) = model_dev.get_decode_batch(
                 {bucket_id: list(zip(token_ids, dec_ids))}, bucket_id)
            _, _, output_logits = model_dev.step(sess, encoder_inputs,
                                                 decoder_inputs,
                                                 target_weights, bucket_id,
                                                 True)
            outputs = [np.argmax(logit, axis=1) for logit in output_logits]
            to_decode = np.array(outputs).T
            num_valid = 0
            for sent_id in range(to_decode.shape[0]):
                parse = list(to_decode[sent_id, :])
                if data_utils.EOS_ID in parse:
                    parse = parse[:parse.index(data_utils.EOS_ID)]
                decoded_parse = []
                for output in parse:
                    if output < len(rev_parse_vocab):
                        decoded_parse.append(
                            tf.compat.as_str(rev_parse_vocab[output]))
                    else:
                        decoded_parse.append("_UNK")
                # balance the brackets so the output parses as a tree
                parse_br, valid = add_brackets(decoded_parse)
                num_valid += valid
                # get gold parse, gold sentence
                gold_parse = [
                    tf.compat.as_str(rev_parse_vocab[output])
                    for output in gold_ids[sent_id]
                ]
                sent_text = [
                    tf.compat.as_str(rev_sent_vocab[output])
                    for output in token_ids[sent_id]
                ]
                # parse with 'XX' fillers also length-matched to the sentence
                parse_mx = match_length(parse_br, sent_text)

                to_write_gold = merge_sent_tree(gold_parse[:-1],
                                                sent_text)  # drop trailing EOS
                to_write_br = merge_sent_tree(parse_br, sent_text)
                to_write_mx = merge_sent_tree(parse_mx, sent_text)

                fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
                fout_br.write('{}\n'.format(' '.join(to_write_br)))
                fout_mx.write('{}\n'.format(' '.join(to_write_mx)))

    # Close the output files, flushing them to disk.
    fout_gold.close()
    fout_br.close()
    fout_mx.close()
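All four examples finish by calling merge_sent_tree to interleave the sentence tokens back into the linearized parse before writing one line per sentence, presumably by filling each 'XX' preterminal placeholder with the next word. A hypothetical sketch under that assumption (the real helper lives elsewhere in the repository):

def merge_sent_tree_sketch(parse_tokens, words):
    # Replace each 'XX' placeholder with the next sentence word;
    # append any leftover words so nothing is silently dropped.
    merged = []
    word_iter = iter(words)
    for tok in parse_tokens:
        merged.append(next(word_iter, tok) if tok == 'XX' else tok)
    merged.extend(word_iter)
    return merged

This is only a sketch; the real helper also has to cope with the length mismatches that match_length corrects for.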