def main(unused_argv):
    if args.config is None:
        raise Exception("Config file must be provided")

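    # load the JSON model config and resolve the dataset and model classes it names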
    with open('model_configs/%s.json' % args.config) as f:
        json_dict = json.load(f)
    input_cls = utils.get_batched_input_class(json_dict.get("dataset", "default"))
    model_cls = utils.get_model_class(json_dict.get("model"))
    print(input_cls)

    hparams = utils.create_hparams(args, model_cls)
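    # paths for HCopy (HTK's feature-extraction tool), taken from configs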
    hparams.hcopy_path = configs.HCOPY_PATH
    hparams.hcopy_config = configs.HCOPY_CONFIG_PATH

    hparams.input_path = os.path.join("tmp", "input.tmp")

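    # build a fresh graph for inference (PREDICT mode)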
    tf.reset_default_graph()
    graph = tf.Graph()
    mode = tf.estimator.ModeKeys.PREDICT

    with graph.as_default():
        trainer = Trainer(hparams, model_cls, input_cls, mode)
        trainer.build_model()

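        # create a session, restore trained weights, and initialize the trainer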
        sess = tf.Session(graph=graph)
        load_model(sess, model_cls, hparams)
        trainer.init(sess)

    os.system('cls' if os.name == 'nt' else 'clear')  # clear screen
    #infer("test.wav", sess, trainer)
    #return
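    # interactive loop: record audio from the microphone, then run inference on it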
    while True:
        if input("Start recording? [Y/n]: ") != 'n':
            print("Recording...")
            record("test.wav")
            # infer(hparams)
            print("Inferring...")
            infer("test.wav", sess, trainer)
Example #2
def eval(hparams, args, Model, BatchedInput):
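    """Run the model over the evaluation set and report per-output error rates."""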
    tf.reset_default_graph()
    graph = tf.Graph()
    mode = tf.estimator.ModeKeys.EVAL
    hparams.batch_size = hparams.eval_batch_size

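    # build the evaluation graph, restore weights, and initialize the input pipeline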
    with graph.as_default():
        trainer = Trainer(hparams, Model, BatchedInput, mode)
        trainer.build_model()

        sess = tf.Session(graph=graph)
        load_model(sess, Model, hparams)
        trainer.init(sess)

        dlgids = []
        lers = {}

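        # progress bar over the whole eval set; the postfix shows running error rates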
        pbar = tqdm(total=trainer.data_size, ncols=100)
        pbar.set_description("Eval")
        fo = open(os.path.join(hparams.summaries_dir, "eval_ret.txt"), "w")
        utils.prepare_output_path(hparams)
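        # per-output accumulators: edit errors and reference lengths, keyed by output index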
        errs = {}
        ref_lens = {}
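        # iterate over batches until the eval dataset raises OutOfRangeError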
        while True:
            try:
                ids, ground_truth_labels, predicted_labels, ground_truth_len, predicted_len = trainer.eval(
                    sess)
                utils.write_log(hparams, [str(ground_truth_labels)])

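                # one decode function and one metric per model output (acc_id)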
                decode_fns = trainer.test_model.get_decode_fns()
                # dlgids += list([str(id).split('/')[-2] for id in ids])
                metrics = (args.metrics or hparams.metrics).split(',')
                for acc_id, (gt_labels, p_labels, gt_len, p_len) in \
                        enumerate(zip(ground_truth_labels, predicted_labels,
                                      ground_truth_len, predicted_len)):
                    if acc_id not in errs:
                        errs[acc_id] = []
                        ref_lens[acc_id] = []
                        lers[acc_id] = []

                    for i in range(len(gt_labels)):
                        if acc_id == 1 and (hparams.model
                                            == "da_attention_seg"):
                            # joint_evaluate returns an error rate directly; treat it
                            # as the error with a unit reference length so the running
                            # average below stays well-defined (assumption)
                            err, str_original, str_decoded = ops_utils.joint_evaluate(
                                hparams,
                                ground_truth_labels[0][i],
                                predicted_labels[0][i],
                                ground_truth_labels[1][i],
                                predicted_labels[1][i],
                                decode_fns[acc_id],
                            )
                            ref_len = 1
                        else:
                            err, ref_len, str_original, str_decoded = ops_utils.evaluate(
                                gt_labels[i],
                                # gt_labels[i][:gt_len[i]],
                                p_labels[i],
                                # p_labels[i][:p_len[i]],
                                decode_fns[acc_id],
                                metrics[acc_id],
                                acc_id)

                        if err is not None:
                            errs[acc_id].append(err)
                            ref_lens[acc_id].append(ref_len)

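                            # join tokens back into readable strings ('_' marks spaces for char units)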
                            if hparams.input_unit == "word":
                                str_original = ' '.join(str_original)
                                str_decoded = ' '.join(str_decoded)
                            elif hparams.input_unit == "char":
                                str_original = ''.join(str_original).replace(
                                    '_', ' ')
                                str_decoded = ''.join(str_decoded).replace(
                                    '_', ' ')

                            tqdm.write(
                                "\nGT: %s\nPR: %s\nLER: %.3f\n" %
                                (str_original, str_decoded, err / ref_len))
                            #tqdm.write(str(p_labels[i]))
                            #tqdm.write("%d %d" % (gt_len[i], p_len[i]))

                            meta = tf.SummaryMetadata()
                            meta.plugin_data.plugin_name = "text"

                # update pbar progress and postfix
                pbar.update(trainer.batch_size)
                bar_pf = {}
                for acc_id in range(len(ground_truth_labels)):
                    # skip outputs that have not accumulated any reference length yet
                    if ref_lens.get(acc_id) and sum(ref_lens[acc_id]) > 0:
                        bar_pf["er" + str(acc_id)] = "%2.2f" % (
                            sum(errs[acc_id]) / sum(ref_lens[acc_id]) * 100)
                pbar.set_postfix(bar_pf)
            except tf.errors.OutOfRangeError:
                break

    # acc_by_ids = {}
    # for i, id in enumerate(dlgids):
    #    if id not in acc_by_ids: acc_by_ids[id] = []
    #    acc_by_ids[id].append(lers[0][i])

    # print("\n\n----- Statistics -----")
    # for id, ls in acc_by_ids.items():
    #     print("%s\t%2.2f" % (id, sum(ls) / len(ls)))

    # fo.write("LER: %2.2f" % (sum(lers) / len(lers) * 100))
    # print(len(lers[0]))
    fo.close()