Ejemplo n.º 1
0
def train():
    """Train the ResNet + GRU CTC model, checkpointing and scoring dev after each epoch.

    Everything this function uses besides its own locals comes from module-level
    names defined elsewhere in the file/project:
    - configuration: ``MAX_INPUT_LEN``, ``FEAT_DIM``, ``BPE_CLASSES``, ``MODEL_DIR``, ...
    - data: ``generator``, ``dev_data``
    - callbacks: ``early_stopper``, ``lr_reducer``, ``csv_logger``
    """

    # RES + RNN
    model = mdl.model_res_gru_ctc(shapes=(MAX_INPUT_LEN, FEAT_DIM, 1),
                                  bpe_classes=BPE_CLASSES,
                                  max_label_len=MAX_LABEL_LEN,
                                  cnn=CNN,
                                  raw_model=RAW_MODEL)

    # The Keras loss for the "ctc_loss" output just returns y_pred.
    # Presumably that output already carries the CTC loss value
    # (TODO confirm in mdl.model_res_gru_ctc).
    parallel_model = mdl.compile(model, gpus=3, lr=0.001,
                                 loss={"ctc_loss": lambda y_true, y_pred: y_pred},
                                 loss_weights={"ctc_loss": 1.0},
                                 metrics=None)

    # Decoding sub-model ('inputs' -> 'ctc_pred'), built under the CPU device scope.
    with tf.device("/cpu:0"):
        ctc_decode_model = mdl.sub_model(model, 'inputs', 'ctc_pred')

    class evaluation(Callback):
        """Save the template model's weights and report dev CTC WER at each epoch end."""

        def on_epoch_end(self, epoch, logs=None):
            with tf.device("/cpu:0"):
                print("============== SAVING =============")
                # Save `model` (the template passed to mdl.compile), not parallel_model.
                model.save_weights("%s/%03d.h5" % (MODEL_DIR, epoch))

                # The evaluation below is CTC decoding (ctc_decode_model + us.ctc_eval),
                # so label it as such, consistent with the other CTC training scripts.
                print("============ CTC EVAL ==========")
                ctc_pred = mdl.ctc_pred(ctc_decode_model, dev_data[0], input_len=ENCODER_LEN,
                                        batch_size=BATCH_SIZE)
                print("DEV-WER:", us.ctc_eval(dev_data[0]["ctc_labels"],
                                              dev_data[0]["ctc_label_len"], ctc_pred, True))

    EVL = evaluation()
    # NOTE(review): validation_steps reuses the *training* step count. With array
    # validation_data some Keras versions ignore it. Confirm this is intended.
    parallel_model.fit_generator(generator=generator, steps_per_epoch=N_BATCHS, epochs=EPOCHS,
                                 callbacks=[early_stopper, lr_reducer, csv_logger, EVL],
                                 initial_epoch=INIT_EPOCH,
                                 validation_data=(dev_data[0], dev_data[1]),
                                 validation_steps=N_BATCHS)
Ejemplo n.º 2
0
def test(model_file="/disc1/ARNet/exp/libri/res18_gru/012.h5"):
    """Decode the LibriSpeech dev/test sets with a trained CTC model and print scores.

    For each of the four splits (dev-clean, dev-other, test-clean, test-other),
    this function:
    1. loads features and CTC targets with ``us.load_ctc``;
    2. decodes them with the ``'inputs' -> 'ctc_pred'`` sub-model;
    3. prints the value returned by ``us.ctc_eval`` under a "*-wer:" label.

    Args:
        model_file: weights file passed to ``mdl.model_res_gru_ctc`` as
            ``raw_model``. Defaults to the checkpoint that used to be hard-coded
            here, so existing ``test()`` calls behave as before.
    """
    # (printed label, scp file) per split, in the order results are reported.
    splits = [
        ("dev-clean-wer:", dev_clean_file),
        ("dev-other-wer:", dev_other_file),
        ("test-clean-wer:", test_clean_file),
        ("test-other-wer:", test_other_file),
    ]

    # Load features and CTC targets for every split before building the model.
    split_data = []
    for _, scp_file in splits:
        feats = us.kaldiio.load_scp(scp_file)
        keys = us.scp2key(us.read_lines(scp_file))
        split_data.append(us.load_ctc(keys, feats,
                                      encoder_len=ENCODER_LEN,
                                      max_input_len=MAX_INPUT_LEN,
                                      max_label_len=MAX_LABEL_LEN,
                                      trans_ids=us.LIBRI_TRANS_IDS))

    model = mdl.model_res_gru_ctc(shapes=(MAX_INPUT_LEN, FEAT_DIM, 1),
                                  bpe_classes=BPE_CLASSES,
                                  max_label_len=MAX_LABEL_LEN,
                                  cnn=CNN,
                                  raw_model=model_file)

    ctc_decode_model = mdl.sub_model(model, 'inputs', 'ctc_pred')

    # Decode every split first, then score them all. This keeps the original
    # ordering: all predictions happen before any evaluation output is printed.
    preds = [mdl.ctc_pred(ctc_decode_model, data[0], input_len=ENCODER_LEN,
                          batch_size=BATCH_SIZE)
             for data in split_data]

    for (label, _), data, pred in zip(splits, split_data, preds):
        print(label,
              us.ctc_eval(data[0]["ctc_labels"], data[0]["ctc_label_len"], pred, True))
Ejemplo n.º 3
0
def train():
    """Train SAR-Net with only its ASR branch enabled (accent-recognition branch off).

    At the end of every epoch, two things happen:
    - the template model's weights are written to ``MODEL_DIR/<epoch>.h5``;
    - the dev set is CTC-decoded and scored with ``us.wer_eval``.

    All configuration and data objects are module-level names defined elsewhere.
    """
    net_config = dict(
        input_shape=[MAX_INPUT_LEN, FEAT_DIM, 1],
        asr_enable=True,
        ar_enable=False,
        res_type=RES_TYPE,
        res_filters=RES_FILTERS,
        hidden_dim=HIDDEN_DIM,
        bn_dim=0,
        encoder_rnn_num=ENCODER_RNN_NUM,
        asr_rnn_num=ASR_RNN_NUM,
        bpe_classes=BPE_CLASSES,
        acc_classes=ACC_CLASSES,
        max_label_len=MAX_LABEL_LEN,
        mto=MANY_TO_ONE,
        metric_loss=METRIC_LOSS,
        margin=MARGIN,
        raw_model=RAW_MODEL,
        name=task,
    )
    model = mdl.SAR_Net(**net_config)

    # Only the "ctc_loss" output is trained; its value is used directly as the loss.
    trainer = mdl.compile(model, gpus=2, lr=0.001,
                          loss={"ctc_loss": lambda y_true, y_pred: y_pred},
                          loss_weights={"ctc_loss": 1.0},
                          metrics={})

    with tf.device("/cpu:0"):
        decoder = mdl.sub_model(model, 'inputs', 'ctc_pred')

    class _EpochEvaluator(Callback):
        """Checkpoint the template model and print dev WER once per epoch."""

        def on_epoch_end(self, epoch, logs=None):
            with tf.device("/cpu:0"):
                print("============== SAVING =============")
                model.save_weights("%s/%03d.h5" % (MODEL_DIR, epoch))

                print("============ ASR EVAL ==========")
                dev_inputs = dev_data[0]
                hyps = mdl.ctc_pred(decoder, dev_inputs,
                                    input_len=ENCODER_LEN,
                                    batch_size=BATCH_SIZE)
                wer = us.wer_eval(dev_inputs["ctc_labels"],
                                  dev_inputs["ctc_label_len"],
                                  hyps, bpe=BPE, show=True)
                print("DEV-WER:", wer)

    trainer.fit_generator(generator=generator,
                          steps_per_epoch=N_BATCHS,
                          epochs=EPOCHS,
                          callbacks=[early_stopper, lr_reducer, csv_logger,
                                     _EpochEvaluator()],
                          initial_epoch=INIT_EPOCH,
                          validation_data=(dev_data[0], dev_data[1]),
                          max_queue_size=20)
Ejemplo n.º 4
0
def train():
    """Train the CTC + accent multi-task model from mdl.model_res_gru_ctc_accent_sex.

    The two objectives are weighted as follows:
    - CTC output: weight 0.4;
    - accent categorical cross-entropy: weight 0.6.

    Two callbacks report on the dev set:
    - at every epoch end, the template model is checkpointed and the dev WER is
      printed;
    - every 300 batches, the dev accent accuracy is printed.

    Configuration and data come from module-level names defined elsewhere.
    """
    model = mdl.model_res_gru_ctc_accent_sex(
        shapes=(MAX_INPUT_LEN, FEAT_DIM, 1),
        accent_classes=ACCENT_CLASSES,
        bpe_classes=BPE_CLASSES,
        max_label_len=MAX_LABEL_LEN,
        cnn=CNN,
        raw_model=RAW_MODEL,
    )

    # The "ctc_loss" output is used as its own loss value; accents use standard CE.
    losses = {
        "ctc_loss": lambda y_true, y_pred: y_pred,
        "accent_labels": "categorical_crossentropy",
    }
    weights = {"ctc_loss": 0.4, "accent_labels": 0.6}
    parallel_model = mdl.compile(model, gpus=4, lr=0.0005,
                                 loss=losses, loss_weights=weights,
                                 metrics={"accent_labels": 'accuracy'})

    with tf.device("/cpu:0"):
        ctc_decoder = mdl.sub_model(model, 'inputs', 'ctc_pred')

    class _DevMonitor(Callback):
        """Per-epoch checkpoint + dev WER; periodic dev accent accuracy."""

        # Accent accuracy is printed whenever the batch index is a multiple of this
        # (batch 0 included).
        accent_every = 300

        def on_epoch_end(self, epoch, logs=None):
            with tf.device("/cpu:0"):
                print("============== SAVING =============")
                # `model` is the closed-over template model, not self.model.
                model.save_weights("%s/%03d.h5" % (MODEL_DIR, epoch))

                print("============ CTC EVAL ==========")
                dev_inputs = dev_data[0]
                hyps = mdl.ctc_pred(ctc_decoder, dev_inputs,
                                    input_len=ENCODER_LEN,
                                    batch_size=BATCH_SIZE)
                score = us.ctc_eval(dev_inputs["ctc_labels"],
                                    dev_inputs["ctc_label_len"], hyps, True)
                print("DEV-WER:", score)

        def on_batch_begin(self, batch, logs=None):
            if batch % self.accent_every != 0:
                return
            with tf.device("/cpu:0"):
                # Output index 1 of the full model is treated as the accent
                # prediction (TODO confirm the output order in mdl).
                accent_pred = model.predict(dev_data[0])[1]
                accent_acc = us.accent_acc(dev_data[1]['accent_labels'], accent_pred)
                print("   iter:%03d dev accent_acc:" % batch, accent_acc)

    parallel_model.fit_generator(
        generator=generator,
        steps_per_epoch=N_BATCHS,
        epochs=EPOCHS,
        callbacks=[early_stopper, lr_reducer, csv_logger, _DevMonitor()],
        initial_epoch=INIT_EPOCH,
        validation_data=(dev_data[0], dev_data[1]),
        validation_steps=N_BATCHS)
Ejemplo n.º 5
0
                              FEATS_DEV,
                              encoder_len=ENCODER_LEN,
                              max_input_len=MAX_INPUT_LEN,
                              max_label_len=MAX_LABEL_LEN,
                              trans_ids=us.AESRC_TRANS_IDS,
                              accent_classes=ACCENT_CLASSES,
                              accent_dct=us.AESRC_ACCENT,
                              accent_ids=us.AESRC_ACCENT2INT)

# Evaluate the accent head on the dev set: overall accuracy first, then one line
# per accent. Each accent's position below is the index passed to us.acc,
# presumably its class id (TODO confirm against us.AESRC_ACCENT2INT).
print("==== test ====")
model = mdl.model_ctc_accent(shapes=(MAX_INPUT_LEN, FEAT_DIM, 1),
                             accent_classes=ACCENT_CLASSES,
                             bpe_classes=BPE_CLASSES,
                             max_label_len=MAX_LABEL_LEN,
                             raw_model=RAW_MODEL)
accent_model = mdl.sub_model(model, 'inputs', 'accent_labels')

accent_pred = accent_model.predict(dev_data[0], batch_size=256)

print("Overall", us.acc(dev_data[1]['accent_labels'], accent_pred))
for _accent_idx, _accent_name in enumerate(("Chinese", "Japanese", "India", "Korea",
                                            "American", "Britain", "Portuguese", "Russia")):
    print(_accent_name, us.acc(dev_data[1]['accent_labels'], accent_pred, _accent_idx))
# Keep the module namespace free of the loop temporaries.
del _accent_idx, _accent_name

# test
# AESRC_FEATS = us.kaldiio.load_scp("data/aesrc_test/feats.scp")