import json
import os

import tensorflow as tf

# Project-local helpers (args, configs, utils, Trainer, load_model, record,
# infer) are assumed to be defined or imported elsewhere in this file.


def main(unused_argv):
    if args.config is None:
        raise Exception("Config file must be provided")

    # Resolve the input pipeline and model classes from the JSON config.
    with open('model_configs/%s.json' % args.config) as f:
        json_dict = json.loads(f.read())
    input_cls = utils.get_batched_input_class(json_dict.get("dataset", "default"))
    model_cls = utils.get_model_class(json_dict.get("model"))
    print(input_cls)

    hparams = utils.create_hparams(args, model_cls)
    hparams.hcopy_path = configs.HCOPY_PATH
    hparams.hcopy_config = configs.HCOPY_CONFIG_PATH
    hparams.input_path = os.path.join("tmp", "input.tmp")

    # Build the inference graph and restore the trained model.
    tf.reset_default_graph()
    graph = tf.Graph()
    mode = tf.estimator.ModeKeys.PREDICT

    with graph.as_default():
        trainer = Trainer(hparams, model_cls, input_cls, mode)
        trainer.build_model()

        sess = tf.Session(graph=graph)
        load_model(sess, model_cls, hparams)
        trainer.init(sess)

    # exit()  # debugging short-circuit from the original; disabled here so the
    #         # interactive loop below is reachable
    os.system('cls')  # clear screen (Windows; use 'clear' on POSIX systems)
    # infer("test.wav", sess, trainer)
    # return

    # Interactive record-and-transcribe loop.
    while True:
        if input("Start recording? [Y/n]: ") != 'n':
            print("Recording...")
            record("test.wav")
            # infer(hparams)
            print("Inferring...")
            infer("test.wav", sess, trainer)
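# The `record` helper called above is not defined in this file. Below is a
# minimal sketch of what it might look like, assuming the third-party
# `sounddevice` and `soundfile` packages; the duration, sample rate, and
# signature here are illustrative assumptions, not the repo's actual
# implementation.
def record(path, seconds=5, sample_rate=16000):
    """Record `seconds` of mono audio from the default mic and save it as WAV."""
    import sounddevice as sd
    import soundfile as sf
    audio = sd.rec(int(seconds * sample_rate), samplerate=sample_rate, channels=1)
    sd.wait()  # block until recording finishes
    sf.write(path, audio, sample_rate)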
# NOTE: this function shadows the built-in eval(); the name is kept for
# compatibility with existing callers.
def eval(hparams, args, Model, BatchedInput):
    """Run the model over the eval set and report corpus-level error rates."""
    tf.reset_default_graph()
    graph = tf.Graph()
    mode = tf.estimator.ModeKeys.EVAL
    hparams.batch_size = hparams.eval_batch_size

    with graph.as_default():
        trainer = Trainer(hparams, Model, BatchedInput, mode)
        trainer.build_model()

        sess = tf.Session(graph=graph)
        load_model(sess, Model, hparams)
        trainer.init(sess)

    dlgids = []

    pbar = tqdm(total=trainer.data_size, ncols=100)
    pbar.set_description("Eval")
    fo = open(os.path.join(hparams.summaries_dir, "eval_ret.txt"), "w")
    utils.prepare_output_path(hparams)

    # Per-output-head accumulators, keyed by acc_id. `lers` was a list in the
    # original (which made `lers[acc_id] = []` raise an IndexError); all three
    # must be dicts.
    lers = {}
    errs = {}
    ref_lens = {}

    while True:
        try:
            ids, ground_truth_labels, predicted_labels, ground_truth_len, predicted_len = \
                trainer.eval(sess)
            utils.write_log(hparams, [str(ground_truth_labels)])
            decode_fns = trainer.test_model.get_decode_fns()
            # dlgids += list([str(id).split('/')[-2] for id in ids])
            metrics = (args.metrics or hparams.metrics).split(',')

            for acc_id, (gt_labels, p_labels, gt_len, p_len) in \
                    enumerate(zip(ground_truth_labels, predicted_labels,
                                  ground_truth_len, predicted_len)):
                if acc_id not in lers:
                    lers[acc_id] = []
                    errs[acc_id] = []      # these two were never initialized in
                    ref_lens[acc_id] = []  # the original, raising a KeyError below

                for i in range(len(gt_labels)):
                    if acc_id == 1 and hparams.model == "da_attention_seg":
                        # Joint evaluation for the segmented dialogue-act model;
                        # this branch reports its own LER and skips the
                        # accumulation below.
                        ler, str_original, str_decoded = ops_utils.joint_evaluate(
                            hparams,
                            ground_truth_labels[0][i], predicted_labels[0][i],
                            ground_truth_labels[1][i], predicted_labels[1][i],
                            decode_fns[acc_id],
                        )
                    else:
                        err, ref_len, str_original, str_decoded = ops_utils.evaluate(
                            gt_labels[i],  # gt_labels[i][:gt_len[i]],
                            p_labels[i],   # p_labels[i][:p_len[i]],
                            decode_fns[acc_id],
                            metrics[acc_id],
                            acc_id)

                        if err is not None:
                            errs[acc_id].append(err)
                            ref_lens[acc_id].append(ref_len)

                            if hparams.input_unit == "word":
                                str_original = ' '.join(str_original)
                                str_decoded = ' '.join(str_decoded)
                            elif hparams.input_unit == "char":
                                str_original = ''.join(str_original).replace('_', ' ')
                                str_decoded = ''.join(str_decoded).replace('_', ' ')

                            tqdm.write("\nGT: %s\nPR: %s\nLER: %.3f\n" %
                                       (str_original, str_decoded, err / ref_len))
                            # tqdm.write(str(p_labels[i]))
                            # tqdm.write("%d %d" % (gt_len[i], p_len[i]))

            meta = tf.SummaryMetadata()  # unused leftover from text summaries
            meta.plugin_data.plugin_name = "text"

            # Update progress bar with the running corpus-level error rates.
            pbar.update(trainer.batch_size)
            bar_pf = {}
            for acc_id in range(len(ground_truth_labels)):
                if ref_lens.get(acc_id):  # guard heads with no scored samples
                    bar_pf["er" + str(acc_id)] = "%2.2f" % (
                        sum(errs[acc_id]) / sum(ref_lens[acc_id]) * 100)
            pbar.set_postfix(bar_pf)
        except tf.errors.OutOfRangeError:
            break

    # acc_by_ids = {}
    # for i, id in enumerate(dlgids):
    #     if id not in acc_by_ids:
    #         acc_by_ids[id] = []
    #     acc_by_ids[id].append(lers[0][i])
    # print("\n\n----- Statistics -----")
    # for id, ls in acc_by_ids.items():
    #     print("%s\t%2.2f" % (id, sum(ls) / len(ls)))
    # fo.write("LER: %2.2f" % (sum(lers) / len(lers) * 100))
    # print(len(lers[0]))
    fo.close()
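# For reference: the error rates reported above are corpus-level, i.e. edit
# distances are summed across utterances and normalized by the total reference
# length, not averaged per utterance. Below is a minimal, self-contained sketch
# of that computation, assuming `ops_utils.evaluate` returns a Levenshtein
# distance and a reference length; the names `_levenshtein` and
# `_corpus_error_rate` are illustrative, not part of this repo.
def _levenshtein(ref, hyp):
    """Edit distance between two token sequences via dynamic programming."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (r != h)))   # substitution
        prev = cur
    return prev[-1]


def _corpus_error_rate(refs, hyps):
    """Total edit distance over total reference length, as a percentage."""
    total_err = sum(_levenshtein(r, h) for r, h in zip(refs, hyps))
    return total_err / sum(len(r) for r in refs) * 100

# Example: _corpus_error_rate([["a", "b"]], [["a", "c"]]) -> 50.0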