Example 1
def do_evalb(model_dev, model, sess, dev_set):  
  # Load vocabularies.
  eval_batch_size = 50
  sents_vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.sents" % FLAGS.input_vocab_size)
  parse_vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.parse" % FLAGS.output_vocab_size)
  sents_vocab, rev_sent_vocab = data_utils.initialize_vocabulary(sents_vocab_path)
  _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)

  gold_file_name = os.path.join(FLAGS.train_dir, 'dev.gold.txt')
  # file with matched brackets
  decoded_br_file_name = os.path.join(FLAGS.train_dir, 'dev.decoded.br.txt')
  # file with filler XX tokens matched to the sentence as well
  decoded_mx_file_name = os.path.join(FLAGS.train_dir, 'dev.decoded.mx.txt')
  
  print("Doing evalb")
  print("Debug - step: ", model.global_step.eval())
  
  num_sents = []
  num_valid_br = []
  num_valid_mx = []
  br = []
  mx = []

  for bucket_id in xrange(len(_buckets)):
    bucket_size = len(dev_set[bucket_id])
    offsets = np.arange(0, bucket_size, eval_batch_size) 
    for batch_offset in offsets:
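        # Reopen the per-batch files each iteration: EVALB is run on each
        # batch's files below, so they only ever hold the current batch.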
        fout_gold = open(gold_file_name, 'w')
        fout_br = open(decoded_br_file_name, 'w')
        fout_mx = open(decoded_mx_file_name, 'w')
        all_examples = dev_set[bucket_id][batch_offset:batch_offset + eval_batch_size]
        model_dev.batch_size = len(all_examples)        
        token_ids = [x[0] for x in all_examples]
        gold_ids = [x[1] for x in all_examples]
        dec_ids = [[]] * len(token_ids)
        encoder_inputs, decoder_inputs, target_weights, seq_len = model_dev.get_decode_batch(
                {bucket_id: zip(token_ids, dec_ids)}, bucket_id)
        _, _, output_logits = model_dev.step(sess, encoder_inputs, decoder_inputs, target_weights, seq_len, bucket_id, True)
        outputs = [np.argmax(logit, axis=1) for logit in output_logits]
        to_decode = np.array(outputs).T
        num_valid = 0
        num_sents.append(to_decode.shape[0])
        for sent_id in range(to_decode.shape[0]):
          parse = list(to_decode[sent_id, :])
          if data_utils.EOS_ID in parse:
            parse = parse[:parse.index(data_utils.EOS_ID)]
          # raw decoded parse
          # print(parse)
          decoded_parse = []
          for output in parse:
              if output < len(rev_parse_vocab):
                decoded_parse.append(tf.compat.as_str(rev_parse_vocab[output]))
              else:
                decoded_parse.append("_UNK") 
          # decoded_parse = [tf.compat.as_str(rev_parse_vocab[output]) for output in parse]
          # add brackets for tree balance
          parse_br, valid = add_brackets(decoded_parse)
          num_valid += valid
          # get gold parse, gold sentence
          gold_parse = [tf.compat.as_str(rev_parse_vocab[output]) for output in gold_ids[sent_id]]
          sent_text = [tf.compat.as_str(rev_sent_vocab[output]) for output in token_ids[sent_id]]
          # parse whose "XX" filler count also matches the sentence length
          parse_mx = match_length(parse_br, sent_text)

          to_write_gold = merge_sent_tree(gold_parse[:-1], sent_text) # account for EOS
          to_write_br = merge_sent_tree(parse_br, sent_text)
          to_write_mx = merge_sent_tree(parse_mx, sent_text)

          fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
          fout_br.write('{}\n'.format(' '.join(to_write_br)))
          fout_mx.write('{}\n'.format(' '.join(to_write_mx)))
          
        # call evalb
        fout_gold.close()
        fout_br.close()
        fout_mx.close()  
   
        # evaluate current batch
        cmd = [evalb_path, '-p', prm_file, gold_file_name, decoded_br_file_name]
        p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out, err = p.communicate()
        out_lines = out.split("\n")
        vv = [x for x in out_lines if "Number of Valid sentence " in x]
        s1 = float(vv[0].split()[-1])
        num_valid_br.append(s1)
        m_br, g_br, t_br = process_eval(out_lines, to_decode.shape[0])
        br.append([m_br, g_br, t_br])

        cmd = [evalb_path, '-p', prm_file, gold_file_name, decoded_mx_file_name]
        p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out, err = p.communicate()
        out_lines = out.split("\n")
        vv = [x for x in out_lines if "Number of Valid sentence " in x]
        s2 = float(vv[0].split()[-1])
        num_valid_mx.append(s2)
        m_mx, g_mx, t_mx = process_eval(out_lines, to_decode.shape[0])
        mx.append([m_mx, g_mx, t_mx])

  br_all = np.array(br)
  mx_all = np.array(mx)
  sum_br_pre = sum(br_all[:,0]) / sum(br_all[:,2])
  sum_br_rec = sum(br_all[:,0]) / sum(br_all[:,1])
  sum_br_f1 = 2 * sum_br_pre*sum_br_rec / (sum_br_rec + sum_br_pre)
  sum_mx_pre = sum(mx_all[:,0]) / sum(mx_all[:,2])
  sum_mx_rec = sum(mx_all[:,0]) / sum(mx_all[:,1])
  sum_mx_f1 = 2 * sum_mx_pre*sum_mx_rec / (sum_mx_rec + sum_mx_pre)
  br_valid = sum(num_valid_br)
  mx_valid = sum(num_valid_mx)
  print("Bracket only -- Num valid sentences: %d; p: %.4f; r: %.4f; f1: %.4f" 
          %(br_valid, sum_br_pre, sum_br_rec, sum_br_f1) ) 
  print("Matched XX   -- Num valid sentences: %d; p: %.4f; r: %.4f; f1: %.4f" 
          %(mx_valid, sum_mx_pre, sum_mx_rec, sum_mx_f1) ) 

  if br_valid >= 5000 and sum_br_f1 >= 0.80:  # F1 is a fraction in [0, 1]
    print("Very good model")
    checkpoint_path = os.path.join(FLAGS.train_dir, "good-model.ckpt")
    model.saver.save(sess, checkpoint_path, global_step=model.global_step)
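
Example 1 aggregates scores by micro-averaging: each batch contributes raw
[matched, gold, test] bracket counts from process_eval, and precision/recall
come from the summed counts rather than from per-batch averages. A minimal
standalone sketch of that arithmetic (the [m, g, t] column layout is the
assumption carried over from the code above):

import numpy as np

def micro_f1(counts):
    # counts: per-batch rows of [matched, gold, test] bracket counts.
    counts = np.asarray(counts, dtype=float)
    matched, gold, test = counts.sum(axis=0)
    prec = matched / test if test else 0.0
    rec = matched / gold if gold else 0.0
    f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
    return prec, rec, f1

print(micro_f1([[40, 50, 45], [35, 50, 50]]))  # (0.7894..., 0.75, 0.7692...)
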
Example 2
def write_decode(model_dev,
                 sess,
                 dev_set,
                 eval_batch_size,
                 globstep,
                 eval_now=False):
    # Load vocabularies.
    sents_vocab_path = os.path.join(FLAGS.data_dir, FLAGS.source_vocab_file)
    parse_vocab_path = os.path.join(FLAGS.data_dir, FLAGS.target_vocab_file)
    sents_vocab, rev_sent_vocab = data_utils.initialize_vocabulary(
        sents_vocab_path)
    _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)

    # current progress
    stepname = str(globstep)
    gold_file_name = os.path.join(FLAGS.train_dir,
                                  'gold-step' + stepname + '.txt')
    print(gold_file_name)
    # file with matched brackets
    decoded_br_file_name = os.path.join(FLAGS.train_dir,
                                        'decoded-br-step' + stepname + '.txt')
    # file with filler XX tokens matched to the sentence as well
    decoded_mx_file_name = os.path.join(FLAGS.train_dir,
                                        'decoded-mx-step' + stepname + '.txt')

    fout_gold = open(gold_file_name, 'w')
    fout_br = open(decoded_br_file_name, 'w')
    fout_mx = open(decoded_mx_file_name, 'w')

    num_dev_sents = 0
    for bucket_id in xrange(len(_buckets)):
        bucket_size = len(dev_set[bucket_id])
        offsets = np.arange(0, bucket_size, eval_batch_size)
        for batch_offset in offsets:
            all_examples = dev_set[bucket_id][batch_offset:batch_offset +
                                              eval_batch_size]
            model_dev.batch_size = len(all_examples)
            token_ids = [x[0] for x in all_examples]
            partition = [x[2] for x in all_examples]
            speech_feats = [x[3] for x in all_examples]
            pbfs = [x[4] for x in all_examples]
            pafs = [x[5] for x in all_examples]
            gold_ids = [x[1] for x in all_examples]
            dec_ids = [[]] * len(token_ids)
            text_encoder_inputs, speech_encoder_inputs, pause_bef, pause_aft, \
                    decoder_inputs, target_weights, \
                    text_seq_len, speech_seq_len = model_dev.get_batch(\
                    {bucket_id: zip(token_ids, dec_ids, partition, speech_feats, pbfs, pafs)}, \
                    bucket_id, batch_offset, FLAGS.use_speech)
            _, _, output_logits = model_dev.step(sess, [text_encoder_inputs, speech_encoder_inputs, \
                    pause_bef, pause_aft],\
                    decoder_inputs, target_weights, text_seq_len, speech_seq_len, \
                    bucket_id, True, FLAGS.use_speech)
            outputs = [np.argmax(logit, axis=1) for logit in output_logits]
            to_decode = np.array(outputs).T
            num_dev_sents += to_decode.shape[0]
            num_valid = 0
            for sent_id in range(to_decode.shape[0]):
                parse = list(to_decode[sent_id, :])
                if data_utils.EOS_ID in parse:
                    parse = parse[:parse.index(data_utils.EOS_ID)]
                decoded_parse = []
                for output in parse:
                    if output < len(rev_parse_vocab):
                        decoded_parse.append(
                            tf.compat.as_str(rev_parse_vocab[output]))
                    else:
                        decoded_parse.append("_UNK")
                # decoded_parse = [tf.compat.as_str(rev_parse_vocab[output]) for output in parse]
                # add brackets for tree balance
                parse_br, valid = add_brackets(decoded_parse)
                num_valid += valid
                # get gold parse, gold sentence
                gold_parse = [
                    tf.compat.as_str(rev_parse_vocab[output])
                    for output in gold_ids[sent_id]
                ]
                sent_text = [
                    tf.compat.as_str(rev_sent_vocab[output])
                    for output in token_ids[sent_id]
                ]
                # parse whose "XX" filler count also matches the sentence length
                parse_mx = match_length(parse_br, sent_text)
                parse_mx = delete_empty_constituents(parse_mx)

                to_write_gold = merge_sent_tree(gold_parse[:-1],
                                                sent_text)  # account for EOS
                to_write_br = merge_sent_tree(parse_br, sent_text)
                to_write_mx = merge_sent_tree(parse_mx, sent_text)

                fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
                fout_br.write('{}\n'.format(' '.join(to_write_br)))
                fout_mx.write('{}\n'.format(' '.join(to_write_mx)))

    # Write to file
    fout_gold.close()
    fout_br.close()
    fout_mx.close()

    f_score_mx = -1.0

    if eval_now:
        f_score_mx = 0.0
        correction_types = ["Bracket only", "Matched XX"]
        corrected_files = [decoded_br_file_name, decoded_mx_file_name]

        for c_type, c_file in zip(correction_types, corrected_files):
            cmd = [evalb_path, '-p', prm_file, gold_file_name, c_file]
            p = subprocess.Popen(cmd,
                                 stdin=subprocess.PIPE,
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
            out, err = p.communicate()
            out_lines = out.split("\n")
            vv = [x for x in out_lines if "Number of Valid sentence " in x]
            if len(vv) == 0:
                return 0.0
            s1 = float(vv[0].split()[-1])
            m_br, g_br, t_br = process_eval(out_lines, num_dev_sents)

            try:
                recall = float(m_br) / float(g_br)
                prec = float(m_br) / float(t_br)
                f_score = 2 * recall * prec / (recall + prec)
            except ZeroDivisionError as e:
                recall, prec, f_score = 0.0, 0.0, 0.0

            print("%s -- Num valid sentences: %d; p: %.4f; r: %.4f; f1: %.4f" \
                    %(c_type, s1, prec, recall, f_score) )
            sys.stdout.flush()
            if "XX" in c_type:
                f_score_mx = f_score
    return f_score_mx
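
Every variant shells out to EVALB the same way: build an argv list, run the
binary, and scan stdout for the "Number of Valid sentence " summary line. A
hedged Python 3 sketch of that invocation pattern (evalb_path and prm_file are
the same externally supplied paths these examples assume):

import subprocess

def run_evalb(evalb_path, prm_file, gold_file, test_file):
    # Run EVALB in text mode and return stdout split into lines.
    cmd = [evalb_path, '-p', prm_file, gold_file, test_file]
    proc = subprocess.run(cmd, capture_output=True, text=True)
    return proc.stdout.split('\n')

out_lines = run_evalb(evalb_path, prm_file, gold_file_name, decoded_mx_file_name)
valid = [x for x in out_lines if "Number of Valid sentence " in x]
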
Example 3
def write_decode(model_dev, sess, dev_set, eval_batch_size, globstep):
    # Load vocabularies. (The reverse vocabularies are used below but their
    # initialization was missing here; restored following Example 2's pattern.)
    sents_vocab_path = os.path.join(FLAGS.data_dir, FLAGS.source_vocab_file)
    parse_vocab_path = os.path.join(FLAGS.data_dir, FLAGS.target_vocab_file)
    sents_vocab, rev_sent_vocab = data_utils.initialize_vocabulary(
        sents_vocab_path)
    _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)

    # current progress
    stepname = str(globstep)
    gold_file_name = os.path.join(FLAGS.train_dir,
                                  'gold-step' + stepname + '.txt')
    print(gold_file_name)
    # file with matched brackets
    decoded_br_file_name = os.path.join(FLAGS.train_dir,
                                        'decoded-br-step' + stepname + '.txt')
    # file with filler XX tokens matched to the sentence as well
    decoded_mx_file_name = os.path.join(FLAGS.train_dir,
                                        'decoded-mx-step' + stepname + '.txt')

    fout_gold = open(gold_file_name, 'w')
    fout_br = open(decoded_br_file_name, 'w')
    fout_mx = open(decoded_mx_file_name, 'w')

    for bucket_id in xrange(len(_buckets)):
        bucket_size = len(dev_set[bucket_id])
        offsets = np.arange(0, bucket_size, eval_batch_size)
        for batch_offset in offsets:
            all_examples = dev_set[bucket_id][batch_offset:batch_offset +
                                              eval_batch_size]
            model_dev.batch_size = len(all_examples)
            token_ids = [x[0] for x in all_examples]
            mfccs = [x[2] for x in all_examples]
            gold_ids = [x[1] for x in all_examples]
            dec_ids = [[]] * len(token_ids)
            text_encoder_inputs, speech_encoder_inputs, decoder_inputs, target_weights, seq_len = model_dev.get_batch(
                {bucket_id: zip(token_ids, dec_ids, mfccs)}, bucket_id)
            _, _, output_logits = model_dev.step(
                sess, [text_encoder_inputs, speech_encoder_inputs],
                decoder_inputs, target_weights, seq_len, bucket_id, True)
            outputs = [np.argmax(logit, axis=1) for logit in output_logits]
            to_decode = np.array(outputs).T
            num_valid = 0
            for sent_id in range(to_decode.shape[0]):
                parse = list(to_decode[sent_id, :])
                if data_utils.EOS_ID in parse:
                    parse = parse[:parse.index(data_utils.EOS_ID)]
                # raw decoded parse
                # print(parse)
                decoded_parse = []
                for output in parse:
                    if output < len(rev_parse_vocab):
                        decoded_parse.append(
                            tf.compat.as_str(rev_parse_vocab[output]))
                    else:
                        decoded_parse.append("_UNK")
                # decoded_parse = [tf.compat.as_str(rev_parse_vocab[output]) for output in parse]
                # add brackets for tree balance
                parse_br, valid = add_brackets(decoded_parse)
                num_valid += valid
                # get gold parse, gold sentence
                gold_parse = [
                    tf.compat.as_str(rev_parse_vocab[output])
                    for output in gold_ids[sent_id]
                ]
                sent_text = [
                    tf.compat.as_str(rev_sent_vocab[output])
                    for output in token_ids[sent_id]
                ]
                # parse whose "XX" filler count also matches the sentence length
                parse_mx = match_length(parse_br, sent_text)
                parse_mx = delete_empty_constituents(parse_mx)

                to_write_gold = merge_sent_tree(gold_parse[:-1],
                                                sent_text)  # account for EOS
                to_write_br = merge_sent_tree(parse_br, sent_text)
                to_write_mx = merge_sent_tree(parse_mx, sent_text)

                fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
                fout_br.write('{}\n'.format(' '.join(to_write_br)))
                fout_mx.write('{}\n'.format(' '.join(to_write_mx)))

    # Write to file
    fout_gold.close()
    fout_br.close()
    fout_mx.close()
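
One detail shared by all of these variants: dec_ids = [[]] * len(token_ids)
creates N references to one shared list, which is only safe because the dummy
decoder ids are never mutated. A quick demonstration of the aliasing:

dec_ids = [[]] * 3
dec_ids[0].append(1)
print(dec_ids)   # [[1], [1], [1]] -- all three entries alias the same list

dec_ids = [[] for _ in range(3)]   # independent lists, safe to mutate
dec_ids[0].append(1)
print(dec_ids)   # [[1], [], []]
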
Example 4
def write_decode(model_dev, sess, dev_set, get_results=False):
    """Perform evaluatio."""
    # Load vocabularies.
    sents_vocab_path = os.path.join(FLAGS.vocab_dir,
                                    FLAGS.source_vocab_file['word'])
    parse_vocab_path = os.path.join(FLAGS.vocab_dir, FLAGS.target_vocab_file)
    sents_vocab, rev_sent_vocab = data_utils.initialize_vocabulary(
        sents_vocab_path)
    _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)

    gold_file_name = os.path.join(FLAGS.train_dir, 'gold.txt')
    # file with matched brackets
    decoded_br_file_name = os.path.join(FLAGS.train_dir, 'decoded.br.txt')
    # file with filler XX tokens matched to the sentence as well
    decoded_mx_file_name = os.path.join(FLAGS.train_dir, 'decoded.mx.txt')
    sent_id_file_name = os.path.join(FLAGS.train_dir, 'sent_id.txt')

    fout_gold = open(gold_file_name, 'w')
    fout_br = open(decoded_br_file_name, 'w')
    fout_mx = open(decoded_mx_file_name, 'w')
    fsent_id = open(sent_id_file_name, 'w')

    num_dev_sents = 0
    for bucket_id in xrange(len(_buckets)):
        bucket_size = len(dev_set[bucket_id])
        offsets = np.arange(0, bucket_size, FLAGS.batch_size)
        for batch_offset in offsets:
            all_examples = dev_set[bucket_id][batch_offset:batch_offset +
                                              FLAGS.batch_size]
            model_dev.batch_size = len(all_examples)
            sent_ids = [x[0] for x in all_examples]
            for sent_id in sent_ids:
                fsent_id.write(sent_id + "\n")

            token_ids = [x[1] for x in all_examples]
            gold_ids = [x[2] for x in all_examples]
            dec_ids = [[]] * len(token_ids)
            encoder_inputs, seq_len, decoder_inputs, seq_len_target =\
                model_dev.get_batch({bucket_id:
                                     zip(sent_ids, token_ids, dec_ids)})
            output_logits = model_dev.step(sess, encoder_inputs, seq_len,
                                           decoder_inputs, seq_len_target)

            outputs = np.argmax(output_logits, axis=1)
            outputs = np.reshape(
                outputs, (max(seq_len_target), model_dev.batch_size))  # T*B

            to_decode = np.array(outputs).T
            num_dev_sents += to_decode.shape[0]
            for sent_id in range(to_decode.shape[0]):
                parse = list(to_decode[sent_id, :])
                if data_utils.EOS_ID in parse:
                    parse = parse[:parse.index(data_utils.EOS_ID)]
                decoded_parse = []
                for output in parse:
                    if output < len(rev_parse_vocab):
                        decoded_parse.append(
                            tf.compat.as_str(rev_parse_vocab[output]))
                    else:
                        decoded_parse.append("_UNK")

                parse_br, valid = add_brackets(decoded_parse)
                # get gold parse, gold sentence
                gold_parse = [
                    tf.compat.as_str(rev_parse_vocab[output])
                    for output in gold_ids[sent_id]
                ]
                sent_text = [
                    tf.compat.as_str(rev_sent_vocab[output])
                    for output in token_ids[sent_id]['word']
                ]
                # parse whose "XX" filler count also matches the sentence length
                parse_mx = match_length(parse_br, sent_text)
                parse_mx = delete_empty_constituents(parse_mx)
                # account for EOS
                to_write_gold = merge_sent_tree(gold_parse[:-1], sent_text)
                to_write_br = merge_sent_tree(parse_br, sent_text)
                to_write_mx = merge_sent_tree(parse_mx, sent_text)

                fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
                fout_br.write('{}\n'.format(' '.join(to_write_br)))
                fout_mx.write('{}\n'.format(' '.join(to_write_mx)))

    # Write to file
    fout_gold.close()
    fout_br.close()
    fout_mx.close()
    fsent_id.close()

    f_score_mx = 0.0
    correction_types = ["Bracket only", "Matched XX"]
    corrected_files = [decoded_br_file_name, decoded_mx_file_name]

    for c_type, c_file in zip(correction_types, corrected_files):
        cmd = [evalb_path, '-p', prm_file, gold_file_name, c_file]
        p = subprocess.Popen(cmd,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        out, err = p.communicate()
        with open(os.path.join(FLAGS.train_dir, "log.txt"), "w") as log_f:
            log_f.write(out)
        out_lines = out.split("\n")
        vv = [x for x in out_lines if "Number of Valid sentence " in x]
        if len(vv) == 0:
            return 0.0
        s1 = float(vv[0].split()[-1])
        m_br, g_br, t_br = process_eval(out_lines, num_dev_sents)

        try:
            recall = float(m_br) / float(g_br)
            prec = float(m_br) / float(t_br)
            f_score = 2 * recall * prec / (recall + prec)
        except ZeroDivisionError:
            recall, prec, f_score = 0.0, 0.0, 0.0

        print("%s -- Num valid sentences: %d; p: %.4f; r: %.4f; f1: %.4f" %
              (c_type, s1, prec, recall, f_score))
        if "XX" in c_type:
            f_score_mx = f_score

    return f_score_mx
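
Each example repairs the raw decoder output with add_brackets before scoring,
but the helper itself is not shown. A hypothetical sketch consistent with how
its return values are used above (a repaired token list plus a 0/1 flag that
feeds num_valid; the project's real implementation may differ):

def add_brackets(tokens):
    # Balance a decoded parse: drop ')' with no matching '(' and close any
    # brackets still open at the end. Hypothetical reconstruction.
    depth = 0
    repaired = []
    for tok in tokens:
        opens, closes = tok.count('('), tok.count(')')
        if closes and depth + opens - closes < 0:
            continue  # unmatched closing bracket; drop the token
        depth += opens - closes
        repaired.append(tok)
    valid = 1 if depth == 0 and repaired == tokens else 0
    repaired.extend([')'] * depth)  # close anything left open
    return repaired, valid
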
Example 5

def write_decode(model_dev, sess, dev_set):
  # Load vocabularies.
  sents_vocab_path = os.path.join(data_dir, "vocab%d.sents" % 90000)
  parse_vocab_path = os.path.join(data_dir, "vocab%d.parse" % 128)
  sents_vocab, rev_sent_vocab = data_utils.initialize_vocabulary(sents_vocab_path)
  _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)

  gold_file_name = os.path.join(train_dir, 'debug.gold.txt')
  # file with matched brackets
  decoded_br_file_name = os.path.join(train_dir, 'debug.decoded.br.txt')
  # file with filler XX tokens matched to the sentence as well
  decoded_mx_file_name = os.path.join(train_dir, 'debug.decoded.mx.txt')
  
  fout_gold = open(gold_file_name, 'w')
  fout_br = open(decoded_br_file_name, 'w')
  fout_mx = open(decoded_mx_file_name, 'w')
  fout_raw = open('raw.txt', 'w')
  fout_sent = open('sent.txt', 'w')

  for bucket_id in xrange(len(_buckets)):
    bucket_size = len(dev_set[bucket_id])
    offsets = np.arange(0, bucket_size, batch_size) 
    for batch_offset in offsets:
        all_examples = dev_set[bucket_id][batch_offset:batch_offset+batch_size]
        model_dev.batch_size = len(all_examples)        
        token_ids = [x[0] for x in all_examples]
        gold_ids = [x[1] for x in all_examples]
        dec_ids = [[]] * len(token_ids)
        encoder_inputs, decoder_inputs, target_weights = model_dev.get_decode_batch(
                {bucket_id: zip(token_ids, dec_ids)}, bucket_id)
        _, _, output_logits = model_dev.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
        outputs = [np.argmax(logit, axis=1) for logit in output_logits]
        to_decode = np.array(outputs).T
        num_valid = 0
        for sent_id in range(to_decode.shape[0]):
          parse = list(to_decode[sent_id, :])
          if data_utils.EOS_ID in parse:
            parse = parse[:parse.index(data_utils.EOS_ID)]
          # raw decoded parse
          # print(parse)
          decoded_parse = []
          for output in parse:
              if output < len(rev_parse_vocab):
                decoded_parse.append(tf.compat.as_str(rev_parse_vocab[output]))
              else:
                decoded_parse.append("_UNK") 
          # decoded_parse = [tf.compat.as_str(rev_parse_vocab[output]) for output in parse]
          # add brackets for tree balance
          parse_br, valid = add_brackets(decoded_parse)
          num_valid += valid
          # get gold parse, gold sentence
          gold_parse = [tf.compat.as_str(rev_parse_vocab[output]) for output in gold_ids[sent_id]]
          sent_text = [tf.compat.as_str(rev_sent_vocab[output]) for output in token_ids[sent_id]]
          # parse whose "XX" filler count also matches the sentence length
          parse_mx = match_length(parse_br, sent_text)
          parse_mx = delete_empty_constituents(parse_mx)

          to_write_gold = merge_sent_tree(gold_parse[:-1], sent_text) # account for EOS
          to_write_br = merge_sent_tree(parse_br, sent_text)
          to_write_mx = merge_sent_tree(parse_mx, sent_text)

          fout_gold.write('{}\n'.format(' '.join(to_write_gold)))
          fout_br.write('{}\n'.format(' '.join(to_write_br)))
          fout_mx.write('{}\n'.format(' '.join(to_write_mx)))
          fout_sent.write('{}\n'.format(' '.join(sent_text)))
          fout_raw.write('{}\n'.format(' '.join(decoded_parse)))
          
  # Write to file
  fout_gold.close()
  fout_br.close()
  fout_mx.close()  
  fout_sent.close()
  fout_raw.close()
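
The "XX" placeholders in the decoded trees stand in for the sentence's words,
and merge_sent_tree splices the real tokens back in before the trees are
written out for EVALB. A hedged sketch of that splice (hypothetical; the real
helper may handle length mismatches differently, which is what match_length
guards against upstream):

def merge_sent_tree(parse, words):
    # Replace each 'XX' leaf in the parse with the next sentence word.
    words = iter(words)
    return [next(words, 'XX') if tok == 'XX' else tok for tok in parse]

# merge_sent_tree(['(S', '(NP', 'XX', ')', '(VP', 'XX', ')', ')'],
#                 ['dogs', 'bark'])
# -> ['(S', '(NP', 'dogs', ')', '(VP', 'bark', ')', ')']
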

Example 6

def do_evalb(model_dev, sess, dev_set, eval_batch_size):
    gold_file_name = os.path.join(train_dir, 'partial.gold.txt')
    # file with matched brackets
    decoded_br_file_name = os.path.join(train_dir, 'partial.decoded.br.txt')
    # file with filler XX tokens matched to the sentence as well
    decoded_mx_file_name = os.path.join(train_dir, 'partial.decoded.mx.txt')
Example 7
sents_vocab, rev_sent_vocab = data_utils.initialize_vocabulary(
    sents_vocab_path)
_, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)

_buckets = [(10, 40), (25, 85), (40, 150)]

gold_file = 'gold_dev.txt'
baseline_file = 'baseline_dev.txt'
fg = open(gold_file, 'w')
fb = open(baseline_file, 'w')

for bucket_id in xrange(len(_buckets)):
    for sentence in dev_set[bucket_id]:
        slength = len(sentence[0])
        toks = sentence[0]
        gold = sentence[1]
        gold_parse = [
            tf.compat.as_str(rev_parse_vocab[output]) for output in gold
        ]
        sent_text = [
            tf.compat.as_str(rev_sent_vocab[output]) for output in toks
        ]
        prediction = random.choice(baseline_dict[slength])
        parse_mx = match_length(prediction.split(), sent_text)
        to_write_mx = merge_sent_tree(parse_mx, sent_text)
        to_write_gold = merge_sent_tree(gold_parse, sent_text)
        fg.write('{}\n'.format(' '.join(to_write_gold)))
        fb.write('{}\n'.format(' '.join(to_write_mx)))

fb.close()
fg.close()
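
Example 7's random baseline draws a training parse of the right length for
each dev sentence, which presumes a baseline_dict keyed by sentence length. A
minimal sketch of how such a dictionary might be built (hypothetical; a
train_pairs list of (token_list, parse_string) tuples is assumed):

from collections import defaultdict

baseline_dict = defaultdict(list)
for toks, parse_str in train_pairs:  # hypothetical training pairs
    baseline_dict[len(toks)].append(parse_str)

# A dev sentence of length n then gets: random.choice(baseline_dict[n])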