Example #1
0
def _evaluate(eval_fn, input_fn, decode_fn, path, config):
    """Decode the evaluation input and score it with ``bleu.bleu``.

    Args:
        eval_fn: Callable applied to the feature dict; its result is fetched
            from the session and converted with ``.tolist()``.
        input_fn: Zero-argument callable returning the feature dict. The dict
            must hold a sized sequence under the key "references".
        decode_fn: Callable applied to the collected id lists (predictions
            and each reference set) before scoring.
        path: Checkpoint directory passed to ``tf.train.ChiefSessionCreator``.
        config: Session config passed to ``tf.train.ChiefSessionCreator``.

    Returns:
        The value returned by ``bleu.bleu(decoded_predictions, decoded_refs)``.
    """
    with tf.Graph().as_default():
        features = input_fn()
        ref_tensors = features["references"]
        fetches = {
            "predictions": eval_fn(features),
            "references": ref_tensors,
        }

        num_ref_sets = len(ref_tensors)
        # One accumulator per reference set, filled batch by batch.
        ref_columns = [[] for _ in range(num_ref_sets)]
        hypotheses = []

        creator = tf.train.ChiefSessionCreator(checkpoint_dir=path,
                                               config=config)

        with tf.train.MonitoredSession(session_creator=creator) as sess:
            # The monitored session signals the end of the input through
            # should_stop().
            while not sess.should_stop():
                fetched = sess.run(fetches)

                # Predictions for this batch, shape: [batch, len]
                hypotheses.extend(fetched["predictions"].tolist())

                # One [batch, len] array per reference set.
                batch_refs = [arr.tolist() for arr in fetched["references"]]
                for idx in range(num_ref_sets):
                    ref_columns[idx].extend(batch_refs[idx])

        decoded_hyps = decode_fn(hypotheses)
        decoded_per_set = [decode_fn(column) for column in ref_columns]
        # Transpose from "per reference set" to "per sentence".
        decoded_per_sentence = [list(group) for group in zip(*decoded_per_set)]

        return bleu.bleu(decoded_hyps, decoded_per_sentence)
Example #2
0
def _evaluate(eval_fn, input_fn, decode_fn, path, config):
    """Decode the evaluation input, undo BPE segmentation and return BLEU.

    The decoded predictions are written to ``model.ckpt-.pred.bpe`` inside
    ``path`` and the BPE-restored predictions to ``model.ckpt-.pred.norm``.
    The restoration is done in-process with the same regular expression the
    previous ``sed`` invocation used, so it no longer depends on a shell, on
    GNU ``sed`` or on ``path`` being free of shell metacharacters.

    Args:
        eval_fn: Callable applied to a dict of "source" / "source_length"
            placeholders; ``result[0][:, 0, :]`` is used as the prediction.
        input_fn: Zero-argument callable returning the feature dict. It must
            contain "source", "source_length" and a sized "references" entry.
        decode_fn: Callable applied to the collected id lists (predictions
            and each reference set) before scoring.
        path: Checkpoint directory; also the directory the two prediction
            files are written to.
        config: Session config passed to ``tf.train.ChiefSessionCreator``.

    Returns:
        The value returned by ``bleu.bleu`` for the restored predictions and
        the decoded references.
    """
    # Local import so the block does not rely on a file-level ``import re``.
    import re

    # Same expression as the former ``sed -r 's/(@@ )|(@@ ?$)//g'``.
    bpe_marker = re.compile(r"(@@ )|(@@ ?$)")

    graph = tf.Graph()
    with graph.as_default():
        features = input_fn()
        refs = features["references"]
        placeholders = {
            "source": tf.placeholder(tf.int32, [None, None], "source"),
            "source_length": tf.placeholder(tf.int32, [None], "source_length")
        }
        predictions = eval_fn(placeholders)
        # NOTE(review): presumably selects the best hypothesis of each
        # sentence from a [batch, beam, len] tensor -- confirm against eval_fn.
        predictions = predictions[0][:, 0, :]

        all_refs = [[] for _ in range(len(refs))]
        all_outputs = []

        sess_creator = tf.train.ChiefSessionCreator(checkpoint_dir=path,
                                                    config=config)

        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            while not sess.should_stop():
                # Pull one batch from the input pipeline, then feed it to the
                # model through the placeholders.
                feats = sess.run(features)
                outputs = sess.run(predictions,
                                   feed_dict={
                                       placeholders["source"]:
                                       feats["source"],
                                       placeholders["source_length"]:
                                       feats["source_length"]
                                   })
                # shape: [batch, len]
                outputs = outputs.tolist()
                # shape: ([batch, len], ..., [batch, len])
                references = [item.tolist() for item in feats["references"]]

                all_outputs.extend(outputs)

                for i in range(len(refs)):
                    all_refs[i].extend(references[i])

        decoded_symbols = decode_fn(all_outputs)
        decoded_refs = [decode_fn(ref_set) for ref_set in all_refs]
        # Transpose from "per reference set" to "per sentence".
        decoded_refs = [list(x) for x in zip(*decoded_refs)]

        save_path_bpe = os.path.join(path, 'model.ckpt-' + '.pred.bpe')
        save_path_norm = os.path.join(path, 'model.ckpt-' + '.pred.norm')

        bpe_lines = [' '.join(sent) for sent in decoded_symbols]
        # Restore from BPE
        norm_lines = [bpe_marker.sub('', line) for line in bpe_lines]

        # Both files are still produced so that anything reading them after
        # evaluation keeps working.
        with open(save_path_bpe, 'w') as f:
            for line in bpe_lines:
                f.write(line + '\n')
        with open(save_path_norm, 'w') as f:
            for line in norm_lines:
                f.write(line + '\n')

        # Tokenise the restored predictions for scoring.
        decoded_symbols = [line.strip().split() for line in norm_lines]

        return bleu.bleu(decoded_symbols, decoded_refs)
Example #3
0
def _evaluate(eval_fn, input_fn, decode_fn, path, config, device_list):
    """Decode the evaluation input on several devices and return BLEU.

    One set of placeholders is created per entry of ``device_list``. Every
    batch pulled from the input pipeline is handed to ``_shard_features``
    together with the placeholders and per-device predictions; the ops and
    feed dict it returns are then run in a single ``sess.run`` call.

    Args:
        eval_fn: Callable passed to ``parallel.data_parallelism`` together
            with the per-device placeholder dicts.
        input_fn: Zero-argument callable returning the feature dict. It must
            contain "source", "source_length" and a sized "references" entry;
            optional "mt_<j>" entries (j in 0..99) get placeholders as well.
        decode_fn: Callable applied to the collected id lists (predictions
            and each reference set) before scoring.
        path: Checkpoint directory passed to ``tf.train.ChiefSessionCreator``.
        config: Session config passed to ``tf.train.ChiefSessionCreator``.
        device_list: Sequence of devices; only its length is used here, the
            sequence itself is forwarded to ``parallel.data_parallelism``.

    Returns:
        The value returned by ``bleu.bleu`` for the decoded predictions
        (with "@@ " BPE joiners removed) and the decoded references.
    """
    graph = tf.Graph()
    with graph.as_default():
        features = input_fn()
        refs = features["references"]
        placeholders = []
        for i in range(len(device_list)):
            placeholders.append({
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i)
            })
            # Mirror every optional "mt_<j>" feature that is present.
            for j in range(100):
                # `in` instead of dict.has_key(), which Python 3 removed.
                if "mt_%d" % j in features:
                    placeholders[-1]["mt_%d" % j] = tf.placeholder(
                        tf.int32, [None, None], "mt_%d_%d" % (j, i))
                    placeholders[-1]["mt_length_%d" % j] = tf.placeholder(
                        tf.int32, [None], "mt_length_%d_%d" % (j, i))

        predictions = parallel.data_parallelism(device_list, eval_fn,
                                                placeholders)
        # NOTE(review): presumably selects the best hypothesis of each
        # sentence from a [batch, beam, len] tensor -- confirm against eval_fn.
        predictions = [pred[0][:, 0, :] for pred in predictions]

        all_refs = [[] for _ in range(len(refs))]
        all_outputs = []

        sess_creator = tf.train.ChiefSessionCreator(checkpoint_dir=path,
                                                    config=config)

        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            while not sess.should_stop():
                feats = sess.run(features)
                # Keep only the features that have a placeholder.
                inp_feats = {key: feats[key] for key in placeholders[0]}
                op, feed_dict = _shard_features(inp_feats, placeholders,
                                                predictions)
                # A list of numpy array with shape: [batch, len]
                outputs = sess.run(op, feed_dict=feed_dict)

                for shard in outputs:
                    all_outputs.extend(shard.tolist())

                # shape: ([batch, len], ..., [batch, len])
                references = [item.tolist() for item in feats["references"]]

                for i in range(len(refs)):
                    all_refs[i].extend(references[i])

        decoded_symbols = decode_fn(all_outputs)

        # Undo BPE segmentation by dropping the "@@ " joiners.
        for i, tokens in enumerate(decoded_symbols):
            decoded_symbols[i] = " ".join(tokens).replace("@@ ", "").split()

        decoded_refs = [decode_fn(ref_set) for ref_set in all_refs]
        # Transpose from "per reference set" to "per sentence".
        decoded_refs = [list(x) for x in zip(*decoded_refs)]

        return bleu.bleu(decoded_symbols, decoded_refs)
Example #4
0
def _evaluate(eval_fn, input_fn, decode_fn, path, config):
    """Decode the context-aware evaluation input and return its BLEU score.

    Args:
        eval_fn: Callable applied to a dict of "source", "source_length",
            "context" and "context_sen_len" placeholders;
            ``result[0][:, 0, :]`` is used as the prediction.
        input_fn: Zero-argument callable returning the feature dict. It must
            provide the four placeholder keys and a sized "references" entry.
        decode_fn: Callable applied to the collected id lists (predictions
            and each reference set) before scoring.
        path: Checkpoint directory passed to ``tf.train.ChiefSessionCreator``.
        config: Session config passed to ``tf.train.ChiefSessionCreator``.

    Returns:
        The value returned by ``bleu.bleu(decoded_predictions, decoded_refs)``.
    """
    with tf.Graph().as_default():
        features = input_fn()
        ref_tensors = features["references"]

        inputs = {
            "source":
            tf.placeholder(tf.int32, [None, None], "source"),
            "source_length":
            tf.placeholder(tf.int32, [None], "source_length"),
            "context":
            tf.placeholder(tf.int32, [None, None, None], "context"),
            "context_sen_len":
            tf.placeholder(tf.int32, [None, None], "context_sen_len"),
        }

        model_out = eval_fn(inputs)
        best = model_out[0][:, 0, :]

        num_ref_sets = len(ref_tensors)
        ref_columns = [[] for _ in range(num_ref_sets)]
        hypotheses = []

        creator = tf.train.ChiefSessionCreator(checkpoint_dir=path,
                                               config=config)

        with tf.train.MonitoredSession(session_creator=creator) as sess:
            while not sess.should_stop():
                batch = sess.run(features)
                # Feed every placeholder from the feature of the same name.
                feed = {inputs[name]: batch[name] for name in inputs}
                decoded_ids = sess.run(best, feed_dict=feed)

                # Predictions for this batch, shape: [batch, len]
                hypotheses.extend(decoded_ids.tolist())

                # One [batch, len] array per reference set.
                batch_refs = [arr.tolist() for arr in batch["references"]]
                for idx in range(num_ref_sets):
                    ref_columns[idx].extend(batch_refs[idx])

        decoded_hyps = decode_fn(hypotheses)
        decoded_per_set = [decode_fn(column) for column in ref_columns]
        # Transpose from "per reference set" to "per sentence".
        decoded_per_sentence = [list(group) for group in zip(*decoded_per_set)]

        return bleu.bleu(decoded_hyps, decoded_per_sentence)
Example #5
0
def _evaluate(eval_fn, input_fn, decode_fn, path, config, batch_size=64):
    """Decode the evaluation input and score it with ``bleu.bleu``.

    Args:
        eval_fn: Callable applied to the feature dict; its result is fetched
            from the session and converted with ``.tolist()``.
        input_fn: Zero-argument callable returning the feature dict, which
            must contain a "references" entry.
        decode_fn: Callable applied to the collected predictions and to each
            of the ``batch_size`` reference accumulators before scoring.
        path: Checkpoint directory passed to ``tf.train.ChiefSessionCreator``.
        config: Session config passed to ``tf.train.ChiefSessionCreator``.
        batch_size: Number of reference accumulators, and the number of rows
            read from the fetched references on every step. Defaults to 64,
            the value that used to be hard-coded.

    Returns:
        The value returned by ``bleu.bleu(decoded_predictions, decoded_refs)``.

    Raises:
        IndexError: If a fetched "references" value has fewer than
            ``batch_size`` rows.
    """
    graph = tf.Graph()
    with graph.as_default():
        features = input_fn()
        refs = features["references"]
        predictions = eval_fn(features)
        results = {
            "predictions": predictions,
            "references": refs
        }

        # NOTE(review): according to the original author's debug output,
        # `refs` is a single tensor of shape (?, ?) here, so neither
        # len(refs) nor its static batch dimension is available; that is why
        # the accumulator count comes from `batch_size`. Whether indexing the
        # references by batch row is the intended layout for bleu.bleu should
        # be confirmed against decode_fn and the input pipeline.
        all_refs = [[] for _ in range(batch_size)]
        all_outputs = []

        sess_creator = tf.train.ChiefSessionCreator(
            checkpoint_dir=path,
            config=config
        )

        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            while not sess.should_stop():
                outputs = sess.run(results)
                # shape: [batch, len]
                batch_predictions = outputs["predictions"].tolist()
                # One list per row of the fetched references.
                references = [item.tolist() for item in outputs["references"]]

                all_outputs.extend(batch_predictions)
                # Indexed by batch row (not by reference set); a batch with
                # fewer than `batch_size` rows raises IndexError.
                for i in range(batch_size):
                    all_refs[i].extend(references[i])

        decoded_symbols = decode_fn(all_outputs)
        decoded_refs = [decode_fn(ref_set) for ref_set in all_refs]
        decoded_refs = [list(x) for x in zip(*decoded_refs)]

        return bleu.bleu(decoded_symbols, decoded_refs)
Example #6
0
# Keep, for every run of `args.beam_size` consecutive candidate lines, the
# candidate with the highest score. Each line is "<text>|||<score>"; a
# trailing incomplete run is ignored.
new_preds = []
with open(args.pred_file_with_bleu, 'r') as f:
    pending = []
    for line_no, raw_line in enumerate(f, 1):
        pending.append(raw_line.strip())
        if line_no % args.beam_size != 0:
            continue

        scored = []
        for candidate in pending:
            fields = candidate.split("|||")
            scored.append((fields[0], float(fields[1].strip("|"))))

        # Stable descending sort: on ties the earliest candidate wins.
        ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
        new_preds.append(ranked[0][0])
        pending = []

# Persist the selected candidates, one per line.
with open(args.new_pred_file, 'w') as f:
    f.writelines(pred + '\n' for pred in new_preds)

# Corpus-level BLEU of the selected candidates against a single reference
# per sentence.
with open(args.refs_file, 'r') as f:
    gold_lines = f.readlines()

golds = [[gold_line.strip().split()] for gold_line in gold_lines]
preds = [pred.split() for pred in new_preds]
bleu_score_corpus = bleu.bleu(preds, golds)

print("BLEU score: %f" % bleu_score_corpus)
Example #7
0
def _evaluate_model(model, sorted_key, dataset, references, params):
    """Decode ``dataset`` across all distributed ranks and return BLEU.

    Every rank decodes its own batches with ``beam_search``, pads the output
    to ``[params.decode_batch_size, 1024]`` and shares it through
    ``dist.all_gather``. Rank 0 converts the gathered rows to strings,
    restores BPE, reorders the sentences with ``sorted_key`` and scores them.

    Args:
        model: Model to evaluate; switched to eval mode for decoding and back
            to train mode afterwards.
        sorted_key: Index mapping; output position ``idx`` receives
            ``results[sorted_key[idx]]``.
        dataset: Iterable of feature batches; it is iterated twice (once to
            count the batches, once to decode).
        references: References forwarded unchanged to ``bleu``.
        params: Hyper-parameters; ``decode_batch_size`` is read here and the
            object is forwarded to ``lookup``, ``beam_search`` and
            ``_convert_to_string``.

    Returns:
        The value of ``bleu(restored_results, references)`` on rank 0 and
        ``0.0`` on every other rank.
    """
    with torch.no_grad():
        model.eval()

        iterator = iter(dataset)
        pad_max = 1024

        # Count the batches once so the progress bar has a total, then
        # restart the iteration for the actual decoding pass.
        total_len = 0
        for _ in iterator:
            total_len += 1
        iterator = iter(dataset)

        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # Buffers for synchronization
        size = torch.zeros([world_size]).long()
        t_list = [
            torch.empty([params.decode_batch_size, pad_max]).long()
            for _ in range(world_size)
        ]
        results = []

        if rank == 0:
            pbar = tqdm(total=total_len)
            pbar.set_description("Validating model")

        while True:
            try:
                features = next(iterator)
                features = lookup(features, "infer", params)
                batch_size = features["source"].shape[0]
            except Exception:
                # A rank that cannot produce a batch (normally because its
                # iterator is exhausted) still has to take part in the
                # collectives below, so it decodes a dummy batch and reports
                # size 0. `Exception` rather than a bare `except` lets
                # KeyboardInterrupt / SystemExit propagate.
                features = {
                    "source": torch.ones([1, 1]).long(),
                    "source_mask": torch.ones([1, 1]).float()
                }
                batch_size = 0

            # Decode
            seqs, _ = beam_search([model], features, params)

            # Padding to the fixed [decode_batch_size, pad_max] buffer shape
            seqs = torch.squeeze(seqs, dim=1)
            pad_batch = params.decode_batch_size - seqs.shape[0]
            pad_length = pad_max - seqs.shape[1]
            seqs = torch.nn.functional.pad(seqs, (0, pad_length, 0, pad_batch))

            # Synchronization: after the all_reduce, size[j] holds the real
            # batch size of rank j for this step.
            size.zero_()
            size[rank].copy_(torch.tensor(batch_size))
            dist.all_reduce(size)
            dist.all_gather(t_list, seqs)

            # Every rank reported an empty batch: all data is consumed.
            if size.sum() == 0:
                break

            # Only rank 0 collects the decoded sentences.
            if rank != 0:
                continue

            for i in range(params.decode_batch_size):
                for j in range(world_size):
                    # Rows at or beyond rank j's real batch size are padding;
                    # skip them before doing any conversion work.
                    if i >= size[j]:
                        continue

                    seq = _convert_to_string(t_list[j][i], params)

                    # Restore BPE segmentation
                    seq = BPE.decode(seq)

                    results.append(seq.split())

            pbar.update(1)

    model.train()

    if rank == 0:
        pbar.close()
        # Reorder the decoded sentences through sorted_key.
        restored_results = [
            results[sorted_key[idx]] for idx in range(len(results))
        ]

        return bleu(restored_results, references)

    return 0.0
Example #8
0
import thumt.utils.bleu as bleu
import argparse

# Compute a smoothed sentence-level BLEU for every prediction line and write
# "<prediction>|||<score>" lines to "<pred_path>.sent-bleu".
parser = argparse.ArgumentParser("Compute sentence bleu.")
parser.add_argument("-pred_path", type=str, required=True)
parser.add_argument("-beam_size", type=int, required=True)
parser.add_argument("-refer_path", type=str, required=True)

args = parser.parse_args()

with open(args.pred_path, 'r') as f:
    preds = f.readlines()

with open(args.refer_path, 'r') as f:
    golds = f.readlines()

# `with` guarantees the summary file is closed even if scoring fails midway.
with open(args.pred_path + ".sent-bleu", 'w') as f_summary:
    for idx, pred in enumerate(preds):
        # Each reference line is shared by `beam_size` consecutive
        # predictions. Floor division keeps the index an int on Python 3,
        # where `/` would produce a float and break the list lookup.
        gold_idx = idx // args.beam_size
        gold = golds[gold_idx].strip()  # remove `\n`
        # bleu.bleu takes a list of hypotheses and, per hypothesis, a list
        # of references; wrap the single pair accordingly.
        refs = [[gold.split()]]
        pred = [pred.strip().split()]
        sent_bleu = bleu.bleu(pred, refs, smooth=True)
        print("%s : %s : %f" % (pred, refs, sent_bleu))
        f_summary.write(" ".join(pred[0]) + "|||" + str(sent_bleu) + "\n")