def _evaluate(eval_fn, input_fn, decode_fn, path, config):
    """Decode the evaluation set from the latest checkpoint and score it.

    A fresh graph is built in which ``eval_fn`` is applied directly to the
    feature dict produced by ``input_fn``. The predictions and the reference
    tensors are fetched together until the input is exhausted. Both are then
    turned into symbols with ``decode_fn`` and scored with ``bleu.bleu``.

    Args:
        eval_fn: callable mapping the feature dict to a prediction tensor.
        input_fn: callable returning a feature dict; its "references" entry
            must support ``len()`` (one tensor per reference set).
        decode_fn: callable converting a list of id sequences to symbols.
        path: checkpoint directory handed to ``ChiefSessionCreator``.
        config: session config handed to ``ChiefSessionCreator``.

    Returns:
        Whatever ``bleu.bleu`` returns for the decoded corpus.
    """
    with tf.Graph().as_default():
        features = input_fn()
        reference_tensors = features["references"]
        fetches = {
            "predictions": eval_fn(features),
            "references": reference_tensors,
        }

        hypotheses = []
        references_by_set = [[] for _ in range(len(reference_tensors))]

        creator = tf.train.ChiefSessionCreator(checkpoint_dir=path, config=config)
        with tf.train.MonitoredSession(session_creator=creator) as session:
            while not session.should_stop():
                batch = session.run(fetches)
                # predictions: [batch, len]
                hypotheses.extend(batch["predictions"].tolist())
                # references: ([batch, len], ..., [batch, len])
                per_set = [arr.tolist() for arr in batch["references"]]
                for bucket, rows in zip(references_by_set, per_set):
                    bucket.extend(rows)

        hyp_symbols = decode_fn(hypotheses)
        # Decode each reference set, then transpose so that entry k holds
        # every reference for hypothesis k.
        decoded_sets = [decode_fn(rows) for rows in references_by_set]
        ref_symbols = [list(group) for group in zip(*decoded_sets)]

        return bleu.bleu(hyp_symbols, ref_symbols)
def _evaluate(eval_fn, input_fn, decode_fn, path, config):
    """Decode the evaluation set, undo BPE and return corpus BLEU.

    Source batches produced by ``input_fn`` are fed through placeholders into
    the graph built by ``eval_fn``. The decoded hypotheses are written
    (still BPE-segmented) to ``<path>/model.ckpt-.pred.bpe``. The
    BPE-restored text goes to ``<path>/model.ckpt-.pred.norm``. The restored
    tokens are scored with ``bleu.bleu`` against the decoded references.

    Args:
        eval_fn: callable taking a dict of placeholders ("source",
            "source_length") and returning a structure whose element 0 is
            sliced as ``[:, 0, :]``.
        input_fn: callable returning a feature dict with "source",
            "source_length" and "references" entries.
        decode_fn: callable converting a list of id sequences to symbol lists.
        path: checkpoint directory; the prediction files are written here too.
        config: session config handed to ``ChiefSessionCreator``.

    Returns:
        Whatever ``bleu.bleu`` returns for the restored hypotheses.
    """
    import re  # BPE restoration is done in-process (formerly via `sed`).

    graph = tf.Graph()
    with graph.as_default():
        features = input_fn()
        refs = features["references"]
        placeholders = {
            "source": tf.placeholder(tf.int32, [None, None], "source"),
            "source_length": tf.placeholder(tf.int32, [None], "source_length"),
        }
        predictions = eval_fn(placeholders)
        # Element 0 of eval_fn's output, index 0 on axis 1 (presumably the
        # top beam -- confirm against eval_fn).
        predictions = predictions[0][:, 0, :]

        all_refs = [[] for _ in range(len(refs))]
        all_outputs = []

        sess_creator = tf.train.ChiefSessionCreator(checkpoint_dir=path, config=config)
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            while not sess.should_stop():
                feats = sess.run(features)
                outputs = sess.run(predictions, feed_dict={
                    placeholders["source"]: feats["source"],
                    placeholders["source_length"]: feats["source_length"],
                })
                # shape: [batch, len]
                all_outputs.extend(outputs.tolist())
                # shape: ([batch, len], ..., [batch, len])
                references = [item.tolist() for item in feats["references"]]
                for i in range(len(refs)):
                    all_refs[i].extend(references[i])

        decoded_symbols = decode_fn(all_outputs)
        decoded_refs = [decode_fn(refs) for refs in all_refs]
        decoded_refs = [list(x) for x in zip(*decoded_refs)]

        # NOTE(review): the file stem is literally "model.ckpt-" with no step
        # number; kept unchanged so existing consumers still find the files.
        save_path_bpe = os.path.join(path, 'model.ckpt-' + '.pred.bpe')
        save_path_norm = os.path.join(path, 'model.ckpt-' + '.pred.norm')

        bpe_lines = [' '.join(sent) for sent in decoded_symbols]
        with open(save_path_bpe, 'w') as f:
            for sent in bpe_lines:
                f.write(sent + '\n')

        # Restore from BPE using the same rule as the former
        # `sed -r 's/(@@ )|(@@ ?$)//g'`, but without building a shell command
        # from an unescaped path or ignoring the command's exit status.
        bpe_pattern = re.compile(r"(@@ )|(@@ ?$)")
        norm_lines = [bpe_pattern.sub("", sent) for sent in bpe_lines]
        with open(save_path_norm, 'w') as f:
            for sent in norm_lines:
                f.write(sent + '\n')

        # Same result as re-reading the .norm file and splitting each line.
        decoded_symbols = [sent.split() for sent in norm_lines]

        return bleu.bleu(decoded_symbols, decoded_refs)
def _evaluate(eval_fn, input_fn, decode_fn, path, config, device_list):
    """Decode the evaluation set on several devices and return corpus BLEU.

    One set of placeholders is built per entry in ``device_list``. The model
    is replicated with ``parallel.data_parallelism``. Each fetched batch is
    split over the replicas by ``_shard_features``. The decoded hypotheses
    are BPE-restored and scored with ``bleu.bleu`` against the decoded
    references.

    Args:
        eval_fn: model function passed to ``parallel.data_parallelism``;
            element 0 of each replica's output is sliced as ``[:, 0, :]``.
        input_fn: callable returning a feature dict with "source",
            "source_length", "references" and optionally "mt_<j>" /
            "mt_length_<j>" entries (j < 100).
        decode_fn: callable converting a list of id sequences to symbol lists.
        path: checkpoint directory handed to ``ChiefSessionCreator``.
        config: session config handed to ``ChiefSessionCreator``.
        device_list: devices to replicate over; one placeholder set each.

    Returns:
        Whatever ``bleu.bleu`` returns for the restored hypotheses.
    """
    import re  # for BPE restoration below

    graph = tf.Graph()
    with graph.as_default():
        features = input_fn()
        refs = features["references"]
        placeholders = []
        for i in range(len(device_list)):
            shard = {
                "source": tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length": tf.placeholder(tf.int32, [None], "source_length_%d" % i),
            }
            # Optional extra inputs "mt_0" ... "mt_99": add placeholders only
            # for those input_fn actually provides. Use `in` rather than
            # dict.has_key, which does not exist on Python 3.
            for j in range(100):
                if "mt_%d" % j in features:
                    shard["mt_%d" % j] = tf.placeholder(
                        tf.int32, [None, None], "mt_%d_%d" % (j, i))
                    shard["mt_length_%d" % j] = tf.placeholder(
                        tf.int32, [None], "mt_length_%d_%d" % (j, i))
            placeholders.append(shard)

        predictions = parallel.data_parallelism(device_list, eval_fn, placeholders)
        # Element 0 of each replica's output, index 0 on axis 1 (presumably
        # the top beam -- confirm against eval_fn).
        predictions = [pred[0][:, 0, :] for pred in predictions]

        all_refs = [[] for _ in range(len(refs))]
        all_outputs = []

        sess_creator = tf.train.ChiefSessionCreator(checkpoint_dir=path, config=config)
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            while not sess.should_stop():
                feats = sess.run(features)
                # Only feed the keys that have placeholders.
                inp_feats = {key: feats[key] for key in placeholders[0].keys()}
                op, feed_dict = _shard_features(inp_feats, placeholders, predictions)
                # A list of numpy array with shape: [batch, len]
                outputs = sess.run(op, feed_dict=feed_dict)
                for shard_output in outputs:
                    all_outputs.extend(shard_output.tolist())
                # shape: ([batch, len], ..., [batch, len])
                references = [item.tolist() for item in feats["references"]]
                for i in range(len(refs)):
                    all_refs[i].extend(references[i])

        decoded_symbols = decode_fn(all_outputs)
        # Undo BPE with the standard subword-nmt rule; unlike a plain
        # .replace("@@ ", ""), it also strips a dangling "@@" at sentence end.
        bpe_pattern = re.compile(r"(@@ )|(@@ ?$)")
        decoded_symbols = [
            bpe_pattern.sub("", " ".join(sent)).split() for sent in decoded_symbols
        ]

        decoded_refs = [decode_fn(refs) for refs in all_refs]
        decoded_refs = [list(x) for x in zip(*decoded_refs)]

        return bleu.bleu(decoded_symbols, decoded_refs)
def _evaluate(eval_fn, input_fn, decode_fn, path, config):
    """Score decodes of a context-aware model with corpus BLEU.

    Four int32 inputs ("source", "source_length", "context",
    "context_sen_len") are fed through placeholders. Each batch is fetched
    from the ``input_fn`` pipeline first and then fed to the graph built by
    ``eval_fn``. Hypotheses and references are decoded with ``decode_fn``
    and scored with ``bleu.bleu``.

    Args:
        eval_fn: callable taking the placeholder dict and returning a
            structure whose element 0 is sliced as ``[:, 0, :]``.
        input_fn: callable returning a feature dict with the four input keys
            above plus "references".
        decode_fn: callable converting a list of id sequences to symbol lists.
        path: checkpoint directory handed to ``ChiefSessionCreator``.
        config: session config handed to ``ChiefSessionCreator``.

    Returns:
        Whatever ``bleu.bleu`` returns for the decoded corpus.
    """
    # (name, shape) for each fed input. The name is used as the placeholder
    # name, as the key in eval_fn's input dict, and as the key in the
    # features returned by input_fn.
    input_specs = (
        ("source", [None, None]),
        ("source_length", [None]),
        ("context", [None, None, None]),
        ("context_sen_len", [None, None]),
    )

    with tf.Graph().as_default():
        features = input_fn()
        reference_tensors = features["references"]
        inputs = {
            name: tf.placeholder(tf.int32, shape, name)
            for name, shape in input_specs
        }
        # Element 0 of eval_fn's output, index 0 on axis 1 (presumably the
        # top beam -- confirm against eval_fn).
        best = eval_fn(inputs)[0][:, 0, :]

        hypotheses = []
        references_by_set = [[] for _ in range(len(reference_tensors))]

        creator = tf.train.ChiefSessionCreator(checkpoint_dir=path, config=config)
        with tf.train.MonitoredSession(session_creator=creator) as session:
            while not session.should_stop():
                batch = session.run(features)
                feed = {inputs[name]: batch[name] for name, _ in input_specs}
                # decoded ids: [batch, len]
                hypotheses.extend(session.run(best, feed_dict=feed).tolist())
                # references: ([batch, len], ..., [batch, len])
                per_set = [arr.tolist() for arr in batch["references"]]
                for bucket, rows in zip(references_by_set, per_set):
                    bucket.extend(rows)

        hyp_symbols = decode_fn(hypotheses)
        # Transpose decoded reference sets: entry k = all refs of sentence k.
        decoded_sets = [decode_fn(rows) for rows in references_by_set]
        ref_symbols = [list(group) for group in zip(*decoded_sets)]

        return bleu.bleu(hyp_symbols, ref_symbols)
def _evaluate(eval_fn, input_fn, decode_fn, path, config):
    """Decode the evaluation set from the latest checkpoint and return BLEU.

    ``features["references"]`` may be a single reference tensor (one
    reference per sentence) or a tuple/list of such tensors (several
    references per sentence). A bare tensor is wrapped into a one-element
    list so both cases share the same bookkeeping. The number of reference
    sets is therefore never confused with the batch size.

    Args:
        eval_fn: callable mapping the feature dict to a prediction tensor.
        input_fn: callable returning a feature dict with a "references" entry.
        decode_fn: callable converting a list of id sequences to symbol lists.
        path: checkpoint directory handed to ``ChiefSessionCreator``.
        config: session config handed to ``ChiefSessionCreator``.

    Returns:
        Whatever ``bleu.bleu`` returns for the decoded corpus.
    """
    graph = tf.Graph()
    with graph.as_default():
        features = input_fn()
        refs = features["references"]
        # A single [batch, len] reference tensor has no usable len() in graph
        # mode. The old code replaced len(refs) with a hard-coded batch size
        # of 64, which crashed on a short final batch and mixed up sentences.
        # Treat a bare tensor as one reference set instead.
        if not isinstance(refs, (list, tuple)):
            refs = [refs]
        predictions = eval_fn(features)
        results = {
            "predictions": predictions,
            "references": refs,
        }

        # One bucket per reference set (not per batch row).
        all_refs = [[] for _ in range(len(refs))]
        all_outputs = []

        sess_creator = tf.train.ChiefSessionCreator(
            checkpoint_dir=path,
            config=config
        )

        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            while not sess.should_stop():
                outputs = sess.run(results)
                # shape: [batch, len]
                predictions = outputs["predictions"].tolist()
                # shape: ([batch, len], ..., [batch, len])
                references = [item.tolist() for item in outputs["references"]]

                all_outputs.extend(predictions)

                for i in range(len(refs)):
                    all_refs[i].extend(references[i])

        decoded_symbols = decode_fn(all_outputs)
        decoded_refs = [decode_fn(ref_rows) for ref_rows in all_refs]
        # Transpose: entry k holds every reference for hypothesis k.
        decoded_refs = [list(x) for x in zip(*decoded_refs)]

        return bleu.bleu(decoded_symbols, decoded_refs)
# Rerank beam candidates by sentence BLEU. The input file holds
# `args.beam_size` consecutive lines per source sentence, each parsed below as
# "<hypothesis>|||<score>". The highest-scoring hypothesis of every full group
# is kept (a trailing incomplete group is ignored), the picks are written to
# `args.new_pred_file`, and their corpus BLEU against `args.refs_file` is
# printed.
new_preds = []
with open(args.pred_file_with_bleu, 'r') as pred_f:
    group = []
    for line_no, raw_line in enumerate(pred_f, 1):
        group.append(raw_line.strip())
        if line_no % args.beam_size:
            continue
        # (hypothesis text, sentence-level score) for every candidate.
        scored = [
            (cand.split("|||")[0], float(cand.split("|||")[1].strip("|")))
            for cand in group
        ]
        # Stable sort: among equal scores the earliest candidate wins.
        ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
        new_preds.append(ranked[0][0])
        group = []

with open(args.new_pred_file, 'w') as out_f:
    out_f.writelines(pred + '\n' for pred in new_preds)

with open(args.refs_file, 'r') as ref_f:
    gold_lines = ref_f.readlines()
# Single reference per sentence, wrapped in a list for bleu.bleu.
golds = [[line.strip().split()] for line in gold_lines]

preds = [pred.split() for pred in new_preds]
bleu_score_corpus = bleu.bleu(preds, golds)
print("BLEU score: %f" % bleu_score_corpus)
def _evaluate_model(model, sorted_key, dataset, references, params):
    """Decode ``dataset`` on every rank and return corpus BLEU on rank 0.

    Each rank beam-searches its own batches. The outputs are padded to a
    fixed shape and all-gathered, so rank 0 can convert, BPE-restore and
    reorder them via ``sorted_key`` before scoring against ``references``.
    A rank that has run out of data keeps joining the collective calls with
    a dummy batch and reports a size of 0. The loop ends once every rank
    reports 0.

    Args:
        model: model to evaluate; put in eval mode while decoding and back
            in train mode afterwards.
        sorted_key: index list used to restore order:
            ``restored[idx] = results[sorted_key[idx]]``.
        dataset: iterable of feature batches; iterated once to count batches
            for the progress bar and once more to decode.
        references: reference token lists passed straight to ``bleu``.
        params: hyper-parameters; ``decode_batch_size`` is read here and the
            object is forwarded to ``lookup``, ``beam_search`` and
            ``_convert_to_string``.

    Returns:
        The BLEU score on rank 0 and ``0.0`` on all other ranks.
    """
    with torch.no_grad():
        model.eval()
        # Fixed length every decoded batch is padded to, so that all_gather
        # receives same-shaped tensors from every rank.
        pad_max = 1024

        # count eval dataset (progress-bar total)
        total_len = sum(1 for _ in dataset)
        iterator = iter(dataset)

        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # Buffers for synchronization
        size = torch.zeros([world_size]).long()
        t_list = [
            torch.empty([params.decode_batch_size, pad_max]).long()
            for _ in range(world_size)
        ]
        results = []

        if rank == 0:
            pbar = tqdm(total=total_len)
            pbar.set_description("Validating model")

        while True:
            try:
                features = next(iterator)
            except StopIteration:
                # Out of data: still take part in the collectives below with
                # a dummy batch, reporting a real batch size of 0. Only
                # exhaustion is handled here -- errors from lookup() or the
                # data pipeline now propagate instead of being mistaken for
                # end-of-data (which silently misaligned results with
                # sorted_key).
                features = {
                    "source": torch.ones([1, 1]).long(),
                    "source_mask": torch.ones([1, 1]).float()
                }
                batch_size = 0
            else:
                features = lookup(features, "infer", params)
                batch_size = features["source"].shape[0]

            # Decode
            seqs, _ = beam_search([model], features, params)

            # Padding
            seqs = torch.squeeze(seqs, dim=1)
            pad_batch = params.decode_batch_size - seqs.shape[0]
            pad_length = pad_max - seqs.shape[1]
            seqs = torch.nn.functional.pad(seqs, (0, pad_length, 0, pad_batch))

            # Synchronization: publish each rank's real batch size and its
            # padded outputs.
            size.zero_()
            size[rank].copy_(torch.tensor(batch_size))
            dist.all_reduce(size)
            dist.all_gather(t_list, seqs)

            # Every rank reported an empty batch: all data consumed.
            if size.sum() == 0:
                break

            if rank != 0:
                continue

            for i in range(params.decode_batch_size):
                for j in range(world_size):
                    # Skip padding rows before paying for string conversion.
                    if i >= size[j]:
                        continue
                    seq = _convert_to_string(t_list[j][i], params)
                    # Restore BPE segmentation
                    seq = BPE.decode(seq)
                    results.append(seq.split())

            pbar.update(1)

    model.train()

    if rank == 0:
        pbar.close()
        restored_results = [results[sorted_key[idx]] for idx in range(len(results))]
        return bleu(restored_results, references)

    return 0.0
# Compute smoothed sentence-level BLEU (bleu.bleu(..., smooth=True)) for
# every candidate line in -pred_path. Line `idx` is scored against line
# `idx // beam_size` of -refer_path, i.e. `beam_size` consecutive candidates
# share one reference. Each candidate is written to "<pred_path>.sent-bleu"
# as "<candidate>|||<score>".
import argparse

import thumt.utils.bleu as bleu

parser = argparse.ArgumentParser("Compute sentence bleu.")
parser.add_argument("-pred_path", type=str, required=True)
parser.add_argument("-beam_size", type=int, required=True)
parser.add_argument("-refer_path", type=str, required=True)
args = parser.parse_args()

with open(args.pred_path, 'r') as f:
    preds = f.readlines()
with open(args.refer_path, 'r') as f:
    golds = f.readlines()

with open(args.pred_path + ".sent-bleu", 'w') as f_summary:
    for idx, pred in enumerate(preds):
        # Floor division: on Python 3 `/` yields a float, which cannot be
        # used as a list index.
        gold_idx = idx // args.beam_size
        gold = golds[gold_idx].strip()  # remove `\n`
        # One hypothesis with a single reference, in the nested-list layout
        # bleu.bleu expects.
        refs = [[gold.split()]]
        pred = [pred.strip().split()]
        sent_bleu = bleu.bleu(pred, refs, smooth=True)
        print("%s : %s : %f" % (pred, refs, sent_bleu))
        f_summary.write(" ".join(pred[0]) + "|||" + str(sent_bleu) + "\n")