def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue
                if name.find("losses_avg") >= 0:
                    continue
                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_list = []
        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_list.append(model)

        params = params_list[0]

        placeholder = {}
        placeholder["source"] = tf.placeholder(tf.int32, [None, None],
                                               "source")
        placeholder["source_length"] = tf.placeholder(tf.int32, [None],
                                                      "source_length")

        enc_fn, dec_fn = model_list[0].get_inference_func()
        enc = enc_fn(placeholder, params)

        state = {}
        state["encoder"] = tf.placeholder(tf.float32,
                                          [None, None, params.hidden_size],
                                          "encoder")
        dec = dec_fn(placeholder, state, params)

        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)

        # Create placeholders
        placeholders = []
        for i in range(len(params.device_list)):
            placeholders.append({
                "source": tf.placeholder(tf.int32, [None, None],
                                         "source_%d" % i),
                "source_length": tf.placeholder(tf.int32, [None],
                                                "source_length_%d" % i)
            })

        # A list of outputs
        if params.generate_samples:
            inference_fn = sampling.create_sampling_graph
        else:
            inference_fn = inference.create_inference_graph

        # Create assign ops
        assign_ops = []
        feed_dict = {}
        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i, feed_dict)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        init_op = tf.tables_initializer()
        results = []

        tf.get_default_graph().finalize()

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables
            sess.run(assign_op, feed_dict=feed_dict)
            sess.run(init_op)

            total_start = time.time()
            while True:
                start = time.time()
                try:
                    feats = sess.run(features)
                    feed_dict = {
                        placeholder["source"]: feats["source"],
                        placeholder["source_length"]: feats["source_length"]
                    }
                    encoder_output = sess.run(enc, feed_dict=feed_dict)
                    encoder_output = encoder_output['encoder']
                    feed_dict_dec = {
                        placeholder["source"]: feats["source"],
                        placeholder["source_length"]: feats["source_length"],
                        state["encoder"]: encoder_output
                    }
                    result = sess.run(dec, feed_dict=feed_dict_dec)
                    #print(result)
                    results.append(result)
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                    end = time.time()
                    print('time:', end - start, 's')
                except tf.errors.OutOfRangeError:
                    break

            total_end = time.time()
            print('total time:', total_end - total_start, 's')

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            print('result', result)
            for item in result:
                outputs.append(item.tolist())
            #for item in result[1]:
            #    scores.append(item.tolist())

        #outputs = list(itertools.chain(*outputs))

        restored_inputs = []
        restored_outputs = []

        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])

        # Write to file
        with open(args.output, "w") as outfile:
            count = 0
            for outputs in restored_outputs:
                print('oup', outputs)
                for output in outputs:
                    outfile.write(str(round(output, 2)) + ' ')
                outfile.write('\n')
                # `scores` is always empty here, so this loop never runs;
                # it appears to be leftover from the score-writing variant.
                for output, score in zip(outputs, scores):
                    decoded = []
                    decoded = " ".join(decoded)
                    outfile.write("%s\n" % decoded)
                count += 1
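# All of the variants in this file call a `set_variables` helper that is not
# shown here.  The function below is only a minimal sketch of what such a
# helper might look like, assuming the module-level `import tensorflow as tf`
# used by the rest of the file and the THUMT-style convention of feeding
# checkpoint values through placeholders; the real implementation may differ.
def set_variables_sketch(var_list, value_dict, prefix, feed_dict=None):
    """Build assign ops that load checkpoint values into graph variables.

    Graph variables live under "<prefix>/..." (e.g. "transformer_0/..."),
    while `value_dict` is keyed by the bare checkpoint scope
    (e.g. "transformer/...").  When `feed_dict` is given, values are fed
    through placeholders so large weight constants are not embedded in the
    graph def (matching the four-argument call sites); otherwise the numpy
    values are assigned directly (matching the three-argument call sites).
    """
    ops = []
    base = prefix.rsplit("_", 1)[0]
    for var in var_list:
        var_name = var.name.split(":")[0]
        ckpt_name = base + var_name[len(prefix):]
        value = value_dict[ckpt_name]
        if feed_dict is not None:
            ph = tf.placeholder(tf.as_dtype(value.dtype), value.shape)
            feed_dict[ph] = value
            ops.append(tf.assign(var, ph))
        else:
            ops.append(tf.assign(var, value))
    return ops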
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            print("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue
                if name.find("losses_avg") >= 0:
                    continue
                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_fns = []
        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]

        #features = dataset.get_inference_input_with_bert(args.input, params)
        if params.use_bert and params.bert_emb_path:
            features = ds.get_inference_input_with_bert(
                params.input + [params.bert_emb_path], params)
        else:
            features = ds.get_inference_input(params.input, params)

        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []
        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))
        results = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            while not sess.should_stop():
                results.append(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)
                if len(results) > 2:
                    break

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []

        for result in results:
            outputs.append(result.tolist())

        outputs = list(itertools.chain(*outputs))
        #restored_outputs = []

        # Write to file
        with open(args.output, "w") as outfile:
            for output in outputs:
                decoded = []
                for idx in output:
                    #if idx == params.mapping["target"][params.eos]:
                    #    if idx != output[-1]:
                    #        print("Warning: incomplete predictions as {}".format(" ".join(output)))
                    #    break
                    decoded.append(vocab[idx])
                decoded = " ".join(decoded)
                outfile.write("%s\n" % decoded)
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue
                if name.find("losses_avg") >= 0:
                    continue
                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_list = []
        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_list.append(model)

        params = params_list[0]

        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)

        # Create placeholders
        placeholders = []
        for i in range(len(params.device_list)):
            placeholders.append({
                "source": tf.placeholder(tf.int32, [None, None],
                                         "source_%d" % i),
                "source_length": tf.placeholder(tf.int32, [None],
                                                "source_length_%d" % i)
            })

        # A list of outputs
        if params.generate_samples:
            inference_fn = sampling.create_sampling_graph
        else:
            inference_fn = inference.create_inference_graph

        predictions = parallel.data_parallelism(
            params.device_list,
            lambda f: inference_fn(model_list, f, params),
            placeholders)

        # Create assign ops
        assign_ops = []
        feed_dict = {}
        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i, feed_dict)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        init_op = tf.tables_initializer()
        results = []

        tf.get_default_graph().finalize()

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables
            sess.run(assign_op, feed_dict=feed_dict)
            sess.run(init_op)

            while True:
                try:
                    feats = sess.run(features)
                    op, feed_dict = shard_features(feats, placeholders,
                                                   predictions)
                    results.append(sess.run(op, feed_dict=feed_dict))
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            for item in result[0]:
                outputs.append(item.tolist())
            for item in result[1]:
                scores.append(item.tolist())

        outputs = list(itertools.chain(*outputs))
        scores = list(itertools.chain(*scores))

        restored_inputs = []
        restored_outputs = []
        restored_scores = []

        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file
        with open(args.output, "w") as outfile:
            count = 0
            for outputs, scores in zip(restored_outputs, restored_scores):
                for output, score in zip(outputs, scores):
                    decoded = []
                    for idx in output:
                        if idx == params.mapping["target"][params.eos]:
                            break
                        decoded.append(vocab[idx])
                    decoded = " ".join(decoded)
                    if not args.verbose:
                        outfile.write("%s\n" % decoded)
                        #break
                    else:
                        pattern = "%d ||| %s ||| %s ||| %f\n"
                        source = restored_inputs[count]
                        values = (count, source, decoded, score)
                        outfile.write(pattern % values)
                count += 1
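# `shard_features` is referenced above but defined elsewhere.  The sketch
# below (modelled on the THUMT-style data-parallel translator, and only an
# assumption about the real helper) splits one fetched batch across the
# per-device placeholder dicts and trims `predictions` to the shards that
# actually received data.
def shard_features_sketch(feats, placeholders, predictions):
    num_shards = len(placeholders)
    feed_dict = {}
    n = 0
    for name in feats:
        feat = feats[name]
        batch = feat.shape[0]
        # Ceiling division so every example lands in some shard.
        shard_size = (batch + num_shards - 1) // num_shards
        for i in range(num_shards):
            shard_feat = feat[i * shard_size:(i + 1) * shard_size]
            if shard_feat.shape[0] != 0:
                feed_dict[placeholders[i][name]] = shard_feat
                n = i + 1
            else:
                break
    # Only run the prediction ops whose placeholders were actually fed.
    return predictions[:n], feed_dict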
def build_graph(params, args, model_list, model_cls_list, model_var_lists,
                problem=None):
    if problem == "parsing":
        fo = args.parsing_output
        fi = args.parsing_input
    elif problem == "amr":
        fo = args.amr_output
        fi = args.amr_input
    else:
        print("problem must be either 'parsing' or 'amr'")

    # Read input file
    sorted_keys, sorted_inputs = dataset.sort_input_file(fi)
    # Build input queue (source data only)
    features = dataset.get_inference_input(sorted_inputs, params)

    # Create placeholders
    placeholders = []
    for i in range(len(params.device_list)):
        placeholders.append({
            "source": tf.placeholder(tf.int32, [None, None],
                                     "source_%d" % i),
            "source_length": tf.placeholder(tf.int32, [None],
                                            "source_length_%d" % i)
        })

    # A list of outputs
    if params.generate_samples:
        inference_fn = sampling.create_sampling_graph
    else:
        inference_fn = inference.create_inference_graph

    predictions = parallel.data_parallelism(
        params.device_list,
        lambda f: inference_fn(model_list, f, params, problem=problem),
        placeholders)

    # Create assign ops
    assign_ops = []
    feed_dict = {}
    all_var_list = tf.trainable_variables()

    for i in range(len(args.checkpoints)):
        un_init_var_list = []
        name = model_cls_list[i].get_name()

        for v in all_var_list:
            if v.name.startswith(name + "_%d" % i):
                un_init_var_list.append(v)

        ops = set_variables(un_init_var_list, model_var_lists[i],
                            name + "_%d" % i, feed_dict)
        assign_ops.extend(ops)

    assign_op = tf.group(*assign_ops)
    init_op = tf.tables_initializer()
    results = []

    tf.get_default_graph().finalize()

    # Create session
    with tf.Session(config=session_config(params)) as sess:
        # Restore variables
        sess.run(assign_op, feed_dict=feed_dict)
        sess.run(init_op)

        while True:
            try:
                feats = sess.run(features)
                op, feed_dict = shard_features(feats, placeholders,
                                               predictions)
                results.append(sess.run(op, feed_dict=feed_dict))
                message = "Finished %s batch %d" % (problem, len(results))
                tf.logging.log(tf.logging.INFO, message)
            except tf.errors.OutOfRangeError:
                break

    # Convert to plain text
    vocab = params.vocabulary[problem + "_target"]
    outputs = []
    scores = []

    for result in results:
        for item in result[0]:
            outputs.append(item.tolist())
        for item in result[1]:
            scores.append(item.tolist())

    outputs = list(itertools.chain(*outputs))
    scores = list(itertools.chain(*scores))

    restored_inputs = []
    restored_outputs = []
    restored_scores = []

    for index in range(len(sorted_inputs)):
        restored_inputs.append(sorted_inputs[sorted_keys[index]])
        restored_outputs.append(outputs[sorted_keys[index]])
        restored_scores.append(scores[sorted_keys[index]])

    # Write to file
    with open(fo, "w") as outfile:
        count = 0
        for outputs, scores in zip(restored_outputs, restored_scores):
            for output, score in zip(outputs, scores):
                decoded = []
                for idx in output:
                    if idx == params.mapping["target"][params.eos]:
                        break
                    decoded.append(vocab[idx])
                decoded = " ".join(decoded)
                if not args.verbose:
                    outfile.write("%s\n" % decoded)
                    break
                else:
                    pattern = "%d ||| %s ||| %s ||| %f\n"
                    source = restored_inputs[count]
                    values = (count, source, decoded, score)
                    outfile.write(pattern % values)
            count += 1
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(args.models)]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        # Import the configuration saved during training
        import_params(args.checkpoints, args.models, params_list[0])
        #for i in range(len([args.checkpoints]))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate([args.checkpoints]):
            print("Loading %s" % checkpoint)
            # List all variables stored in the checkpoint
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                # Keep only the model's variables, skipping "losses_avg"
                if not name.startswith(model_cls_list[i].get_name()):
                    continue
                if name.find("losses_avg") >= 0:
                    continue
                # Read the tensor value
                tensor = reader.get_tensor(name)
                values[name] = tensor

            # Values of all model variables without "losses_avg"
            model_var_lists.append(values)

        # Build models
        model_fns = []
        for i in range(len([args.checkpoints])):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            # Get the model's inference function
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]

        #features = dataset.get_inference_input_with_bert(args.input, params)
        if params.use_bert and params.bert_emb_path:
            features = dataset.get_inference_input_with_bert(
                params.input + [params.bert_emb_path], params)
        else:
            features = dataset.get_inference_input([params.input], params)

        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []
        all_var_list = tf.trainable_variables()

        for i in range(len([args.checkpoints])):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))

        result_for_score = []
        result_for_write = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            # Record the length of each input sentence; used later to drop
            # padded positions.
            lenth = []
            with open(args.input, "r", encoding="utf8") as f:
                for line in f:
                    if line.strip() == "-DOCSTART-":
                        continue
                    lines = line.strip().split(" ")
                    lenth.append(len(lines))

            current_num = 0
            batch = 0
            while not sess.should_stop():
                currrent_res_arr = sess.run(predictions)
                result_for_write.append(currrent_res_arr)
                for arr in currrent_res_arr:
                    result_for_score.extend(list(arr)[:lenth[current_num]])
                    current_num += 1
                batch += 1
                message = "Finished batch %d" % batch
                tf.logging.log(tf.logging.INFO, message)

        if params.is_validation:
            from sklearn.metrics import precision_score, recall_score, f1_score
            import numpy as np

            # Map labels to indices
            voc_lis = params.vocabulary["target"]
            index = list(np.arange(len(voc_lis)))
            dic = dict(zip(voc_lis, index))

            def map_res(x):
                return dic[x]

            targets_list = []
            # Read the gold label file
            with open(args.eval_file, "r") as f:
                for line in f:
                    if line.strip() == "O":
                        continue
                    lines = line.strip().split(" ")
                    # Convert labels to indices
                    targets_list.extend(list(map(map_res, lines)))

            result_arr = np.array(result_for_score)
            targets_arr = np.array(targets_list)
            # sklearn expects (y_true, y_pred)
            precision_ = precision_score(targets_arr, result_arr,
                                         average="micro", labels=[0, 2, 3, 4])
            recall_ = recall_score(targets_arr, result_arr,
                                   average="micro", labels=[0, 2, 3, 4])
            print("precision_score:{}".format(precision_))
            print("recall_score:{}".format(recall_))
            print("F1_score:{}".format(2 * precision_ * recall_ /
                                       (recall_ + precision_)))
        else:
            # Convert to plain text
            vocab = params.vocabulary["target"]
            outputs = []

            for result in result_for_write:
                outputs.append(result.tolist())

            outputs = list(itertools.chain(*outputs))
            #restored_outputs = []

            # Write to file
            num = 0
            with open(args.output, "w") as outfile:
                for output in outputs:
                    decoded = []
                    for idx in output[:lenth[num] + 1]:
                        if idx == params.mapping["target"][params.eos]:
                            if idx != output[lenth[num]]:
                                print("Warning: incomplete prediction at "
                                      "line {} of the source file"
                                      .format(num + 1))
                        decoded.append(vocab[idx])
                    decoded = " ".join(decoded[:-1])
                    outfile.write("%s\n" % decoded)
                    num += 1
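# Note: `f1_score` is imported in the validation branch above but never used;
# the micro-averaged F1 it returns should equal the 2*P*R/(P+R) value that is
# printed there.  A small, hypothetical sanity check (the arrays and the
# function name are made up for illustration only):
def _micro_f1_sanity_check():
    import numpy as np
    from sklearn.metrics import f1_score, precision_score, recall_score
    y_true = np.array([0, 2, 3, 4, 1, 0])
    y_pred = np.array([0, 2, 4, 4, 1, 2])
    labels = [0, 2, 3, 4]
    p = precision_score(y_true, y_pred, average="micro", labels=labels)
    r = recall_score(y_true, y_pred, average="micro", labels=labels)
    # Both expressions compute the same micro-averaged F1 over `labels`.
    f1 = f1_score(y_true, y_pred, average="micro", labels=labels)
    assert abs(2 * p * r / (p + r) - f1) < 1e-6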
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(args.models)]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        # Import the configuration saved during training
        import_params(args.checkpoints, args.models, params_list[0])
        #for i in range(len([args.checkpoints]))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate([args.checkpoints]):
            print("Loading %s" % checkpoint)
            # List all variables stored in the checkpoint
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                # Keep only the model's variables, skipping "losses_avg"
                if not name.startswith(model_cls_list[i].get_name()):
                    continue
                if name.find("losses_avg") >= 0:
                    continue
                # Read the tensor value
                tensor = reader.get_tensor(name)
                values[name] = tensor

            # Values of all model variables without "losses_avg"
            model_var_lists.append(values)

        # Build models
        model_fns = []
        for i in range(len([args.checkpoints])):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            # Get the model's inference function
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]

        #features = dataset.get_inference_input_with_bert(args.input, params)
        if params.use_bert and params.bert_emb_path:
            features = dataset.get_inference_input_with_bert(
                params.input + [params.bert_emb_path], params)
        else:
            features = dataset.get_inference_input([params.input], params)

        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []
        all_var_list = tf.trainable_variables()

        for i in range(len([args.checkpoints])):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))
        results = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            while not sess.should_stop():
                results.extend(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)

        tar = []
        with open(params.input, "r") as inputs_f:
            for line in inputs_f:
                if line.strip() == "O":
                    continue
                else:
                    tar.extend(line.split(" ")[:-1])
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue
                if name.find("losses_avg") >= 0:
                    continue
                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_list = []
        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_list.append(model)

        params = params_list[0]

        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)

        # Create placeholders
        placeholders = []
        for i in range(len(params.device_list)):
            placeholders.append({
                "source": tf.placeholder(tf.int32, [None, None],
                                         "source_%d" % i),
                "source_length": tf.placeholder(tf.int32, [None],
                                                "source_length_%d" % i)
            })

        # A list of outputs
        if params.generate_samples:
            inference_fn = sampling.create_sampling_graph
        else:
            inference_fn = inference.create_inference_graph

        predictions = parallel.data_parallelism(
            params.device_list,
            lambda f: inference_fn(model_list, f, params),
            placeholders)

        # Create assign ops
        assign_ops = []
        feed_dict = {}
        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i, feed_dict)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        init_op = tf.tables_initializer()
        results = []

        tf.get_default_graph().finalize()

        tf.logging.info(args.models[0])
        if args.models[0] == 'transformer_raw_t5':
            t5_list = []
            for var in tf.trainable_variables():
                if 'en_t5_bias_mat' in var.name or 'de_self_relative_attention_bias' in var.name:
                    t5_list.append(var)
                    tf.logging.info(var)
            for op in tf.get_default_graph().get_operations():
                if 'encoder_t5_bias' in op.name or 'decoder_t5_bias' in op.name:
                    if 'random' in op.name or 'read' in op.name or 'Assign' in op.name or 'placeholder' in op.name:
                        continue
                    t5_list.append(op.values()[0])
                    tf.logging.info(op.values()[0].name)
        elif args.models[0] == 'transformer_raw_soft_t5':
            soft_t5_bias_list = []
            for op in tf.get_default_graph().get_operations():
                if 'soft_t5_bias' in op.name or 'soft_t5_encoder' in op.name or 'soft_t5_decoder' in op.name:
                    if 'random' in op.name or 'read' in op.name or 'Assign' in op.name or 'placeholder' in op.name or 'decoder' in op.name:
                        continue
                    soft_t5_bias_list.append(op.values()[0])
                    tf.logging.info(op.values()[0].name)

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables
            sess.run(assign_op, feed_dict=feed_dict)
            sess.run(init_op)

            while True:
                try:
                    feats = sess.run(features)
                    op, feed_dict = shard_features(feats, placeholders,
                                                   predictions)
                    results.append(sess.run(op, feed_dict=feed_dict))
                    '''
                    if args.models[0] == 'transformer_raw_t5':
                        var_en_bucket = tf.get_default_graph().get_tensor_by_name(t5_list[0].name)
                        var_de_bucket = tf.get_default_graph().get_tensor_by_name(t5_list[1].name)
                        var_en_bias = tf.get_default_graph().get_tensor_by_name(t5_list[2].name)
                        en_bucket, de_bucket, en_t5_bias = sess.run(
                            [var_en_bucket, var_de_bucket, var_en_bias],
                            feed_dict=feed_dict)
                        ret_param = {'en_bucket': en_bucket, 'de_bucket': en_bucket,
                                     'en_t5_bias': en_t5_bias}
                        pickle.dump(ret_param,
                                    open(args.checkpoints[0] + '/' + 't5_bias.pkl', 'wb'))
                        tf.logging.info('store the t5 bias')
                    elif args.models[0] == 'transformer_raw_soft_t5':
                        var_en_alpha = tf.get_default_graph().get_tensor_by_name(soft_t5_bias_list[0].name)
                        var_en_beta = tf.get_default_graph().get_tensor_by_name(soft_t5_bias_list[1].name)
                        var_en_t5_bias = tf.get_default_graph().get_tensor_by_name(soft_t5_bias_list[2].name)
                        en_alpha, en_beta, en_t5_bias = sess.run(
                            [var_en_alpha, var_en_beta, var_en_t5_bias],
                            feed_dict=feed_dict)
                        ret_param = {'en_t5_bias': en_t5_bias, 'en_alpha': en_alpha,
                                     'en_beta': en_beta}
                        pickle.dump(ret_param,
                                    open(args.checkpoints[0] + '/' + 'soft_t5_bias.pkl', 'wb'))
                        tf.logging.info('store the soft-t5 bias')
                    '''
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            for shard in result:
                for item in shard[0]:
                    outputs.append(item.tolist())
                for item in shard[1]:
                    scores.append(item.tolist())

        restored_inputs = []
        restored_outputs = []
        restored_scores = []

        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file
        if sys.version_info.major == 2:
            outfile = open(args.output, "w")
        elif sys.version_info.major == 3:
            outfile = open(args.output, "w", encoding="utf-8")
        else:
            raise ValueError("Unknown Python running environment!")

        count = 0
        for outputs, scores in zip(restored_outputs, restored_scores):
            for output, score in zip(outputs, scores):
                decoded = []
                for idx in output:
                    if idx == params.mapping["target"][params.eos]:
                        break
                    decoded.append(vocab[idx])
                decoded = " ".join(decoded)
                if not args.verbose:
                    outfile.write("%s\n" % decoded)
                else:
                    pattern = "%d ||| %s ||| %s ||| %f\n"
                    source = restored_inputs[count]
                    values = (count, source, decoded, score)
                    outfile.write(pattern % values)
            count += 1

        outfile.close()
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue
                if name.find("losses_avg") >= 0:
                    continue
                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_fns = []
        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]

        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)

        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []
        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))
        results = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            while not sess.should_stop():
                results.append(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            outputs.append(result[0].tolist())
            scores.append(result[1].tolist())

        outputs = list(itertools.chain(*outputs))
        scores = list(itertools.chain(*scores))

        restored_inputs = []
        restored_outputs = []
        restored_scores = []

        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file
        with open(args.output, "w") as outfile:
            count = 0
            for outputs, scores in zip(restored_outputs, restored_scores):
                for output, score in zip(outputs, scores):
                    decoded = []
                    for idx in output:
                        if idx == params.mapping["target"][params.eos]:
                            break
                        decoded.append(vocab[idx])
                    decoded = " ".join(decoded)
                    if not args.verbose:
                        outfile.write("%s\n" % decoded)
                        break
                    else:
                        pattern = "%d ||| %s ||| %s ||| %f\n"
                        source = restored_inputs[count]
                        values = (count, source, decoded, score)
                        outfile.write(pattern % values)
                count += 1
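# Several variants above call `dataset.sort_input_file`, which is defined
# elsewhere.  The sketch below shows the behaviour the call sites appear to
# assume (sort sentences by length for efficient batching and remember how to
# restore the original order); the actual helper may differ.
def sort_input_file_sketch(filename, reverse=True):
    with open(filename) as fd:
        inputs = [line.strip() for line in fd]
    # Sort sentence indices by token count (longest first by default).
    order = sorted(range(len(inputs)),
                   key=lambda i: len(inputs[i].split()),
                   reverse=reverse)
    sorted_inputs = [inputs[i] for i in order]
    # sorted_keys[original_index] -> position in the sorted list, so that
    # `outputs[sorted_keys[index]]` (as used above) restores the input order.
    sorted_keys = {orig: pos for pos, orig in enumerate(order)}
    return sorted_keys, sorted_inputs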