Example no. 1
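All snippets below come from the same project and rely on module-level names that the extracts do not show, most notably a global `args` configuration object and the project's own dataset and model classes. A minimal sketch of the imports they assume (the project-specific module paths are guesses and are left commented out):

import numpy as np
import tensorflow as tf
from pathlib import Path
from datetime import datetime
from time import time
# Project-specific classes referenced by the examples; the import paths below are assumptions:
# from dataset import ASR_align_DataSet, ASR_align_ArkDataSet, TextDataSet
# from tfData import TFData
# from models import PhoneClassifier, PhoneDiscriminator2, PhoneDiscriminator3, Encoder, Decoder, CLM, GRNN, GRNN_Cell
# `args` is the parsed configuration object created elsewhere in the project.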
def split_save(capacity=10000):
    fw = open(args.dirs.train.tfdata / '0.csv', 'w')
    with open(args.dirs.train.data) as f:
        for num, line in enumerate(f):
            if num % capacity == 0:
                idx_file = num // capacity
                print('package file ', idx_file)
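                # close the finished shard and open the next one; on the very first
                # line this just re-opens 0.csv, and any error is silently swallowed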
                try:
                    fw.close()
                    fw = open(
                        args.dirs.train.tfdata / (str(idx_file) + '.csv'), 'w')
                except:
                    pass
            fw.write(line)
    print('processed {} utts.'.format(num + 1))
    fw.close()

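    # convert each CSV shard just written into a record file (named after the shard index) via TFData.save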
    for i in Path(args.dirs.train.tfdata).glob('*.csv'):
        print('converting {} to record'.format(i.name))
        dataset_train = ASR_align_DataSet(file=[i],
                                          args=args,
                                          _shuffle=False,
                                          transform=True)
        tfdata_train = TFData(dataset=dataset_train,
                              dataAttr=['feature', 'label', 'align'],
                              dir_save=args.dirs.train.tfdata,
                              args=args)

        tfdata_train.save(i.name.split('.')[0])
Example no. 2
def Decode(save_file):
    dataset = ASR_align_DataSet(trans_file=args.dirs.train.trans,
                                align_file=None,
                                uttid2wav=args.dirs.train.wav_scp,
                                feat_len_file=args.dirs.train.feat_len,
                                args=args,
                                _shuffle=False,
                                transform=True)
    dataset_dev = ASR_align_DataSet(trans_file=args.dirs.dev.trans,
                                    align_file=args.dirs.dev.align,
                                    uttid2wav=args.dirs.dev.wav_scp,
                                    feat_len_file=args.dirs.dev.feat_len,
                                    args=args,
                                    _shuffle=False,
                                    transform=True)

    feature_dev = TFData(dataset=dataset_dev,
                         dir_save=args.dirs.dev.tfdata,
                         args=args).read()
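    # each element is a (uttid, feature) pair; padded_batch pads the features to the longest utterance in the batch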
    feature_dev = feature_dev.padded_batch(args.batch_size,
                                           ((), [None, args.dim_input]))

    G = PhoneClassifier(args)
    G.summary()

    optimizer_G = tf.keras.optimizers.Adam(1e-4)
    ckpt = tf.train.Checkpoint(G=G, optimizer_G=optimizer_G)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dirs.checkpoint,
                                              max_to_keep=1)
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('checkpoint {} restored!!'.format(ckpt_manager.latest_checkpoint))
    fer, cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, G)
    decode(dataset, G, args.idx2token, 'output/' + save_file)
Example no. 3
def Decode(save_file):
    dataset_dev = ASR_align_ArkDataSet(scp_file=args.dirs.dev.scp,
                                       trans_file=args.dirs.dev.trans,
                                       align_file=None,
                                       feat_len_file=None,
                                       args=args,
                                       _shuffle=False,
                                       transform=False)
    feature_dev = TFData(dataset=dataset_dev,
                         dir_save=args.dirs.dev.tfdata,
                         args=args).read()
    feature_dev = feature_dev.padded_batch(args.batch_size,
                                           ((), [None, args.dim_input]))

    encoder = Encoder(args)
    decoder = Decoder(args)
    D = CLM(args)
    encoder.summary()
    decoder.summary()
    D.summary()

    ckpt_G = tf.train.Checkpoint(encoder=encoder, decoder=decoder)
    _ckpt_manager = tf.train.CheckpointManager(ckpt_G,
                                               args.dirs.checkpoint,
                                               max_to_keep=1)
    ckpt_G.restore(_ckpt_manager.latest_checkpoint)
    print('checkpoint {} restored!!'.format(_ckpt_manager.latest_checkpoint))
    cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, encoder,
                   decoder)
    print('PER:{:.3f}'.format(cer))
Example no. 4
def Decode(save_file):
    dataset_dev = ASR_align_ArkDataSet(scp_file=args.dirs.dev.scp,
                                       trans_file=args.dirs.dev.trans,
                                       align_file=None,
                                       feat_len_file=None,
                                       args=args,
                                       _shuffle=False,
                                       transform=False)
    feature_dev = TFData(dataset=dataset_dev,
                         dir_save=args.dirs.dev.tfdata,
                         args=args).read()
    feature_dev = feature_dev.padded_batch(args.batch_size,
                                           ((), [None, args.dim_input]))

    model = Model(args)
    model.summary()

    optimizer = tf.keras.optimizers.Adam(1e-4)
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)

    _ckpt_manager = tf.train.CheckpointManager(ckpt,
                                               args.dirs.checkpoint,
                                               max_to_keep=1)
    ckpt.restore(_ckpt_manager.latest_checkpoint)
    print('checkpoint {} restored!!'.format(_ckpt_manager.latest_checkpoint))
    cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, model)
    print('PER:{:.3f}'.format(cer))
Example no. 5
def Decode(save_file):
    # dataset = ASR_align_DataSet(
    #     trans_file=args.dirs.train.trans,
    #     align_file=None,
    #     uttid2wav=args.dirs.train.wav_scp,
    #     feat_len_file=args.dirs.train.feat_len,
    #     args=args,
    #     _shuffle=False,
    #     transform=True)
    dataset_dev = ASR_align_DataSet(trans_file=args.dirs.dev.trans,
                                    align_file=args.dirs.dev.align,
                                    uttid2wav=args.dirs.dev.wav_scp,
                                    feat_len_file=args.dirs.dev.feat_len,
                                    args=args,
                                    _shuffle=False,
                                    transform=True)
    feature_dev = TFData(dataset=dataset_dev,
                         dir_save=args.dirs.dev.tfdata,
                         args=args).read()
    feature_dev = feature_dev.padded_batch(args.batch_size,
                                           ((), [None, args.dim_input]))

    model = PhoneClassifier(args)
    model.summary()

    optimizer_G = tf.keras.optimizers.Adam(1e-4)
    ckpt = tf.train.Checkpoint(model=model, optimizer_G=optimizer_G)

    _ckpt_manager = tf.train.CheckpointManager(ckpt,
                                               args.dirs.checkpoint,
                                               max_to_keep=1)
    ckpt.restore(_ckpt_manager.latest_checkpoint)
    print('checkpoint {} restored!!'.format(_ckpt_manager.latest_checkpoint))
    # fer, cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, model, beam_size=0, with_stamp=True)
    # fer, cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, model, beam_size=0, with_stamp=False)
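    # decode_outs is a preallocated int32 buffer handed to the external WFST library; presumably it receives the decoded ids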
    decode_outs = np.zeros((300), dtype=np.int32)
    wfst = WFST_Decoder(decode_outs=decode_outs,
                        fcdll="../WFST/libctc_wfst_lib.so",
                        fcfg="wfst/timit.json")
    fer, cer = evaluate(feature_dev,
                        dataset_dev,
                        args.data.dev_size,
                        model,
                        wfst=wfst)
    print('FER:{:.3f}\t WFST PER:{:.3f}'.format(fer, cer))
    fer, cer = evaluate(feature_dev,
                        dataset_dev,
                        args.data.dev_size,
                        model,
                        beam_size=0,
                        with_stamp=False)
    print('FER:{:.3f}\t PER:{:.3f}'.format(fer, cer))
Example no. 6
def main():
    dataset_train = ASR_align_DataSet(file=[args.dirs.train.data],
                                      args=args,
                                      _shuffle=False,
                                      transform=True)
    dataset_dev = ASR_align_DataSet(file=[args.dirs.dev.data],
                                    args=args,
                                    _shuffle=False,
                                    transform=True)
    tfdata_train = TFData(dataset=dataset_train,
                          dataAttr=['feature', 'label', 'align'],
                          dir_save=args.dirs.train.tfdata,
                          args=args)
    tfdata_dev = TFData(dataset=dataset_dev,
                        dataAttr=['feature', 'label', 'align'],
                        dir_save=args.dirs.dev.tfdata,
                        args=args)
    tfdata_train.save('0')
    tfdata_dev.save('0')
    # tfdata_train.get_bucket_size(100, True)
    # split_save()
    # for sample in tfdata_dev.read():
    # # for sample in dataset_train:
    #     # print(sample['feature'].shape)
    #     # print(sample['label'])
    #     print(sample[0].shape)
    #     import pdb; pdb.set_trace()
    dataset_train.get_dataset_ngram(n=args.data.ngram,
                                    k=10000,
                                    savefile=args.dirs.ngram)
Example no. 7
def main():
    # dataset_train = ASR_align_ArkDataSet(
    #     scp_file=args.dirs.train.scp,
    #     trans_file=args.dirs.train.trans,
    #     align_file=None,
    #     feat_len_file=None,
    #     args=args,
    #     _shuffle=True,
    #     transform=False)
    # dataset_dev = ASR_align_ArkDataSet(
    #     scp_file=args.dirs.dev.scp,
    #     trans_file=args.dirs.dev.trans,
    #     align_file=None,
    #     feat_len_file=None,
    #     args=args,
    #     _shuffle=False,
    #     transform=False)
    # dataset_untrain = ASR_align_ArkDataSet(
    #     scp_file=args.dirs.untrain.scp,
    #     trans_file=None,
    #     align_file=None,
    #     feat_len_file=None,
    #     args=args,
    #     _shuffle=True,
    #     transform=False)
    dataset_train = ASR_align_DataSet(trans_file=args.dirs.train.trans,
                                      align_file=args.dirs.train.align,
                                      uttid2wav=args.dirs.train.wav_scp,
                                      feat_len_file=args.dirs.train.feat_len,
                                      args=args,
                                      _shuffle=False,
                                      transform=True)
    dataset_dev = ASR_align_DataSet(trans_file=args.dirs.dev.trans,
                                    uttid2wav=args.dirs.dev.wav_scp,
                                    align_file=args.dirs.dev.align,
                                    feat_len_file=args.dirs.dev.feat_len,
                                    args=args,
                                    _shuffle=False,
                                    transform=True)
    feature_train = TFData(dataset=dataset_train,
                           dir_save=args.dirs.train.tfdata,
                           args=args)
    # feature_untrain = TFData(dataset=dataset_untrain,
    #                 dir_save=args.dirs.untrain.tfdata,
    #                 args=args)
    # feature_train_supervise = TFData(dataset=dataset_train_supervise,
    #                 dir_save=args.dirs.train_supervise.tfdata,
    #                 args=args)
    feature_dev = TFData(dataset=dataset_dev,
                         dir_save=args.dirs.dev.tfdata,
                         args=args)
    feature_train.split_save(capacity=100000)
    feature_dev.split_save(capacity=100000)
Example no. 8
def Decode(save_file):
    dataset = ASR_align_DataSet(trans_file=None,
                                align_file=None,
                                uttid2wav=args.dirs.train.wav_scp,
                                feat_len_file=args.dirs.train.feat_len,
                                args=args,
                                _shuffle=False,
                                transform=True)
    dataset_dev = ASR_align_DataSet(trans_file=args.dirs.dev.trans,
                                    align_file=args.dirs.dev.align,
                                    uttid2wav=args.dirs.dev.wav_scp,
                                    feat_len_file=args.dirs.dev.feat_len,
                                    args=args,
                                    _shuffle=False,
                                    transform=True)

    feature_train = TFData(dataset=dataset,
                           dir_save=args.dirs.train.tfdata,
                           args=args).read()
    feature_dev = TFData(dataset=dataset_dev,
                         dir_save=args.dirs.dev.tfdata,
                         args=args).read()

    feature_train = feature_train.cache().repeat().shuffle(500).padded_batch(
        args.batch_size, ((), [None, args.dim_input])).prefetch(buffer_size=5)
    feature_dev = feature_dev.padded_batch(args.batch_size,
                                           ((), [None, args.dim_input]))

    # model = GRNN(args)
    model = GRNN_Cell(args)
    model.summary()

    optimizer = tf.keras.optimizers.Adam(1e-4)
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dirs.checkpoint,
                                              max_to_keep=1)
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('checkpoint {} restored!!'.format(ckpt_manager.latest_checkpoint))

    # plot_align(model, dataset_dev, 0)

    for batch in feature_train:
        uttids, x = batch
        batch_size = len(x)
        h1 = h2 = tf.zeros([batch_size, args.model.num_hidden_rnn])
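        # unroll the recurrent cell one frame at a time, threading the two hidden states through the sequence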
        for x_t in tf.split(x, x.shape[1], axis=1):

            x_fc, x_cell, h1, h2 = model([x_t, h1, h2], training=False)
            z, r = get_GRU_activation(model.layers[3],
                                      cell_inputs=x_fc[:, 0, :],
                                      hiddens=h1)
Example no. 9
def main():
    dataset_train = ASR_align_ArkDataSet(scp_file=args.dirs.train.scp,
                                         trans_file=args.dirs.train.trans,
                                         align_file=None,
                                         feat_len_file=None,
                                         args=args,
                                         _shuffle=True,
                                         transform=False)
    dataset_dev = ASR_align_ArkDataSet(scp_file=args.dirs.dev.scp,
                                       trans_file=args.dirs.dev.trans,
                                       align_file=None,
                                       feat_len_file=None,
                                       args=args,
                                       _shuffle=False,
                                       transform=False)
    dataset_untrain = ASR_align_ArkDataSet(scp_file=args.dirs.untrain.scp,
                                           trans_file=None,
                                           align_file=None,
                                           feat_len_file=None,
                                           args=args,
                                           _shuffle=True,
                                           transform=False)
    feature_train = TFData(dataset=dataset_train,
                           dir_save=args.dirs.train.tfdata,
                           args=args)
    feature_untrain = TFData(dataset=dataset_untrain,
                             dir_save=args.dirs.untrain.tfdata,
                             args=args)
    # feature_train_supervise = TFData(dataset=dataset_train_supervise,
    #                 dir_save=args.dirs.train_supervise.tfdata,
    #                 args=args)
    feature_dev = TFData(dataset=dataset_dev,
                         dir_save=args.dirs.dev.tfdata,
                         args=args)
    # feature_train.split_save(capacity=100000)
    # feature_dev.split_save(capacity=100000)
    feature_untrain.split_save(capacity=100000)
Example no. 10
def train():
    with tf.device("/cpu:0"):
        dataset_train = ASR_align_DataSet(uttid2wav=args.dirs.train.wav_scp,
                                          trans_file=None,
                                          align_file=None,
                                          feat_len_file=None,
                                          args=args,
                                          _shuffle=True,
                                          transform=True)
        dataset_dev = ASR_align_DataSet(uttid2wav=args.dirs.dev.wav_scp,
                                        trans_file=None,
                                        align_file=None,
                                        feat_len_file=None,
                                        args=args,
                                        _shuffle=False,
                                        transform=True)
        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                             dir_save=args.dirs.dev.tfdata,
                             args=args).read()

        feature_train = feature_train.cache().repeat().shuffle(
            500).padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5)
        feature_dev = feature_dev.padded_batch(args.batch_size,
                                               ((), [None, args.dim_input]))

    # create model parameters
    model = GRNN(args)
    model.summary()
    optimizer = tf.keras.optimizers.Adam(args.opti.lr, beta_1=0.5, beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=10)
    step = 0

    if args.dirs.checkpoint:
        _ckpt_manager = tf.train.CheckpointManager(ckpt,
                                                   args.dirs.checkpoint,
                                                   max_to_keep=1)
        ckpt.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint {} restored!!'.format(
            _ckpt_manager.latest_checkpoint))
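        # resume the global step counter from the checkpoint name (CheckpointManager names checkpoints ckpt-<step>)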
        step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    progress = 0

    for batch in feature_train:
        start = time()

        uttids, x = batch
        loss = train_step(x, model, optimizer)

        num_processed += len(x)
        progress = num_processed / args.data.train_size
        if step % 10 == 0:
            print('loss: {:.3f}\tbatch: {}\tused: {:.3f}\t {:.3f}% step: {}'.
                  format(loss, x.shape,
                         time() - start, progress * 100.0, step))
            with writer.as_default():
                tf.summary.scalar("losses/loss", loss, step=step)
        if step % args.dev_step == 0:
            loss_dev = evaluate(feature_dev, model)
            with writer.as_default():
                tf.summary.scalar("losses/loss_dev", loss_dev, step=step)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example no. 11
def Train():
    dataset_train = ASR_align_DataSet(trans_file=args.dirs.train.trans,
                                      align_file=args.dirs.train.align,
                                      uttid2wav=args.dirs.train.wav_scp,
                                      feat_len_file=args.dirs.train.feat_len,
                                      args=args,
                                      _shuffle=True,
                                      transform=True)
    dataset_dev = ASR_align_DataSet(trans_file=args.dirs.dev.trans,
                                    align_file=args.dirs.dev.align,
                                    uttid2wav=args.dirs.dev.wav_scp,
                                    feat_len_file=args.dirs.dev.feat_len,
                                    args=args,
                                    _shuffle=False,
                                    transform=True)
    with tf.device("/cpu:0"):
        # wav data
        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                             dir_save=args.dirs.dev.tfdata,
                             args=args).read()
        if args.num_supervised:
            dataset_train_supervise = ASR_align_DataSet(
                trans_file=args.dirs.train_supervise.trans,
                align_file=args.dirs.train_supervise.align,
                uttid2wav=args.dirs.train_supervise.wav_scp,
                feat_len_file=args.dirs.train_supervise.feat_len,
                args=args,
                _shuffle=False,
                transform=True)
            feature_train_supervise = TFData(
                dataset=dataset_train_supervise,
                dir_save=args.dirs.train_supervise.tfdata,
                args=args).read()
            supervise_uttids, supervise_x = next(iter(feature_train_supervise.take(args.num_supervised).\
                padded_batch(args.num_supervised, ((), [None, args.dim_input]))))
            supervise_aligns = dataset_train_supervise.get_attrs(
                'align', supervise_uttids.numpy())
            # supervise_bounds = dataset_train_supervise.get_attrs('bounds', supervise_uttids.numpy())

        iter_feature_train = iter(
            feature_train.repeat().shuffle(500).padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        feature_dev = feature_dev.padded_batch(args.batch_size,
                                               ((), [None, args.dim_input]))

    # create model parameters
    model = PhoneClassifier(args)
    model.summary()
    optimizer_G = tf.keras.optimizers.Adam(args.opti.lr,
                                           beta_1=0.5,
                                           beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(model=model, optimizer_G=optimizer_G)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    step = 0

    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        _ckpt_manager = tf.train.CheckpointManager(ckpt,
                                                   args.dirs.checkpoint,
                                                   max_to_keep=1)
        ckpt.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint {} restored!!'.format(
            _ckpt_manager.latest_checkpoint))
        step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()

    while step < 99999999:
        start = time()

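        # with num_supervised set, train on the one fixed labelled batch; otherwise draw a fresh batch and fetch its frame-level alignments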
        if args.num_supervised:
            x = supervise_x
            loss_supervise = train_G_supervised(supervise_x, supervise_aligns,
                                                model, optimizer_G,
                                                args.dim_output)
            # loss_supervise, bounds_loss = train_G_bounds_supervised(
            #     x, supervise_bounds, supervise_aligns, model, optimizer_G, args.dim_output)
        else:
            uttids, x = next(iter_feature_train)
            aligns = dataset_train.get_attrs('align', uttids.numpy())
            # trans = dataset_train.get_attrs('trans', uttids.numpy())
            loss_supervise = train_G_supervised(x, aligns, model, optimizer_G,
                                                args.dim_output)
            # loss_supervise = train_G_TBTT_supervised(x, aligns, model, optimizer_G, args.dim_output)
            # bounds = dataset_train.get_attrs('bounds', uttids.numpy())
            # loss_supervise, bounds_loss = train_G_bounds_supervised(x, bounds, aligns, model, optimizer_G, args.dim_output)
            # loss_supervise = train_G_CTC_supervised(x, trans, model, optimizer_G)

        if step % 10 == 0:
            print('loss_supervise: {:.3f}\tbatch: {}\tused: {:.3f}\tstep: {}'.
                  format(loss_supervise, x.shape,
                         time() - start, step))
            # print('loss_supervise: {:.3f}\tloss_bounds: {:.3f}\tbatch: {}\tused: {:.3f}\tstep: {}'.format(
            #        loss_supervise, bounds_loss, x.shape, time()-start, step))
            with writer.as_default():
                tf.summary.scalar("costs/loss_supervise",
                                  loss_supervise,
                                  step=step)
        if step % args.dev_step == 0:
            fer, cer_0 = evaluate(feature_dev,
                                  dataset_dev,
                                  args.data.dev_size,
                                  model,
                                  beam_size=0,
                                  with_stamp=True)
            fer, cer = evaluate(feature_dev,
                                dataset_dev,
                                args.data.dev_size,
                                model,
                                beam_size=0,
                                with_stamp=False)
            with writer.as_default():
                tf.summary.scalar("performance/fer", fer, step=step)
                tf.summary.scalar("performance/cer_0", cer_0, step=step)
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            monitor(dataset_dev[0], model)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example no. 12
def Train():
    args.data.untrain_size = TFData.read_tfdata_info(
        args.dirs.untrain.tfdata)['size_dataset']
    with tf.device("/cpu:0"):
        dataset_train = ASR_align_ArkDataSet(scp_file=args.dirs.train.scp,
                                             trans_file=args.dirs.train.trans,
                                             align_file=None,
                                             feat_len_file=None,
                                             args=args,
                                             _shuffle=False,
                                             transform=False)
        dataset_untrain = ASR_align_ArkDataSet(scp_file=args.dirs.untrain.scp,
                                               trans_file=None,
                                               align_file=None,
                                               feat_len_file=None,
                                               args=args,
                                               _shuffle=False,
                                               transform=False)
        dataset_dev = ASR_align_ArkDataSet(scp_file=args.dirs.dev.scp,
                                           trans_file=args.dirs.dev.trans,
                                           align_file=None,
                                           feat_len_file=None,
                                           args=args,
                                           _shuffle=False,
                                           transform=False)
        # wav data
        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read()
        feature_unsupervise = TFData(dataset=dataset_untrain,
                                     dir_save=args.dirs.untrain.tfdata,
                                     args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                             dir_save=args.dirs.dev.tfdata,
                             args=args).read()
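        # bucket utterances by their frame count so that each batch only pads to a similar length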
        bucket = tf.data.experimental.bucket_by_sequence_length(
            element_length_func=lambda uttid, x: tf.shape(x)[0],
            bucket_boundaries=args.list_bucket_boundaries,
            bucket_batch_sizes=args.list_batch_size,
            padded_shapes=((), [None, args.dim_input]))
        iter_feature_train = iter(
            feature_train.repeat().shuffle(100).apply(bucket).prefetch(
                buffer_size=5))
        # iter_feature_unsupervise = iter(feature_unsupervise.repeat().shuffle(100).apply(bucket).prefetch(buffer_size=5))
        # iter_feature_train = iter(feature_train.repeat().shuffle(100).padded_batch(args.batch_size,
        #         ((), [None, args.dim_input])).prefetch(buffer_size=5))
        iter_feature_unsupervise = iter(
            feature_unsupervise.repeat().shuffle(100).padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        # feature_dev = feature_dev.apply(bucket).prefetch(buffer_size=5)
        feature_dev = feature_dev.padded_batch(args.batch_size,
                                               ((), [None, args.dim_input]))

        dataset_text = TextDataSet(list_files=[args.dirs.lm.data],
                                   args=args,
                                   _shuffle=True)
        tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                      (tf.TensorShape([None])))
        iter_text = iter(tfdata_train.cache().repeat().shuffle(1000).map(
            lambda x: x[:args.model.D.max_label_len]).padded_batch(
                args.text_batch_size,
                ([args.model.D.max_label_len])).prefetch(buffer_size=5))

    # create model parameters
    encoder = Encoder(args)
    decoder = Decoder(args)
    D = CLM(args)
    encoder.summary()
    decoder.summary()
    D.summary()
    optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.5, beta_2=0.9)
    optimizer_D = tf.keras.optimizers.Adam(0.0001, beta_1=0.5, beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt_G = tf.train.Checkpoint(encoder=encoder, decoder=decoder)
    ckpt_manager = tf.train.CheckpointManager(ckpt_G,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    step = 0

    if args.dirs.checkpoint_G:
        _ckpt_manager = tf.train.CheckpointManager(ckpt_G,
                                                   args.dirs.checkpoint_G,
                                                   max_to_keep=1)
        ckpt_G.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint_G {} restored!!'.format(
            _ckpt_manager.latest_checkpoint))
        # cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, encoder, decoder)
        # with writer.as_default():
        #     tf.summary.scalar("performance/cer", cer, step=step)

    start_time = datetime.now()
    num_processed = 0
    while step < 99999999:
        start = time()

        # supervised training
        uttids, x = next(iter_feature_train)
        trans = dataset_train.get_attrs('trans', uttids.numpy())
        loss_supervise = train_CTC_supervised(x, trans, encoder, decoder,
                                              optimizer)

        # unsupervised training
        text = next(iter_text)
        _, un_x = next(iter_feature_unsupervise)
        # loss_G = train_G(un_x, encoder, decoder, D, optimizer, args.model.D.max_label_len)
        loss_G = train_G(un_x, encoder, decoder, D, optimizer,
                         args.model.D.max_label_len)
        loss_D = train_D(un_x, text, encoder, decoder, D, optimizer_D,
                         args.lambda_gp, args.model.D.max_label_len)

        num_processed += len(un_x)
        progress = num_processed / args.data.untrain_size

        if step % 10 == 0:
            print(
                'loss_supervise: {:.3f}\tloss_G: {:.3f}\tloss_D: {:.3f}\tbatch: {}\tused: {:.3f}\t {:.3f}% step: {}'
                .format(loss_supervise, loss_G, loss_D, un_x.shape,
                        time() - start, progress * 100, step))
            with writer.as_default():
                tf.summary.scalar("costs/loss_supervise",
                                  loss_supervise,
                                  step=step)
        if step % args.dev_step == args.dev_step - 1:
            cer = evaluate(feature_dev, dataset_dev, args.data.dev_size,
                           encoder, decoder)
            with writer.as_default():
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            monitor(dataset_dev[0], encoder, decoder)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example no. 13
def train():
    # load external LM
    with tf.device("/cpu:0"):
        dataset_train = ASR_align_DataSet(
            trans_file=args.dirs.train.trans,
            align_file=args.dirs.train.align,
            uttid2wav=args.dirs.train.wav_scp,
            feat_len_file=args.dirs.train.feat_len,
            args=args,
            _shuffle=True,
            transform=True)
        dataset_dev = ASR_align_DataSet(
            trans_file=args.dirs.dev.trans,
            align_file=args.dirs.dev.align,
            uttid2wav=args.dirs.dev.wav_scp,
            feat_len_file=args.dirs.dev.feat_len,
            args=args,
            _shuffle=False,
            transform=True)
        dataset_train_supervise = ASR_align_DataSet(
            trans_file=args.dirs.train_supervise.trans,
            align_file=args.dirs.train_supervise.align,
            uttid2wav=args.dirs.train.wav_scp,
            feat_len_file=args.dirs.train.feat_len,
            args=args,
            _shuffle=False,
            transform=True)
        feature_train = TFData(dataset=dataset_train,
                        dir_save=args.dirs.train.tfdata,
                        args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                        dir_save=args.dirs.dev.tfdata,
                        args=args).read()
        supervise_uttids, supervise_x = next(iter(feature_train.take(args.num_supervised).\
            padded_batch(args.num_supervised, ((), [None, args.dim_input]))))
        supervise_aligns = dataset_train_supervise.get_attrs('align', supervise_uttids.numpy())

        iter_feature_train = iter(feature_train.cache().repeat().shuffle(500).padded_batch(args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        feature_dev = feature_dev.padded_batch(args.batch_size, ((), [None, args.dim_input]))

    # get dataset ngram
    ngram_py, total_num = read_ngram(args.data.k, args.dirs.ngram, args.token2idx, type='list')
    kernel, py = ngram2kernel(ngram_py, args)

    # create model parameters
    G = PhoneClassifier(args)
    compute_p_ngram = P_Ngram(kernel, args)
    G.summary()
    compute_p_ngram.summary()

    # build optimizer
    if args.opti.type == 'adam':
        optimizer = tf.keras.optimizers.Adam(args.opti.lr, beta_1=0.5, beta_2=0.9)
    elif args.opti.type == 'sgd':
        optimizer = tf.keras.optimizers.SGD(lr=args.opti.lr, momentum=0.9, decay=0.98)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(G=G, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, args.dir_checkpoint, max_to_keep=5)
    step = 0

    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        _ckpt_manager = tf.train.CheckpointManager(ckpt, args.dirs.checkpoint, max_to_keep=1)
        ckpt.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint {} restored!!'.format(_ckpt_manager.latest_checkpoint))
        step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    progress = 0

    while step < 99999999:
        start = time()

        uttids, x = next(iter_feature_train)
        stamps = dataset_train.get_attrs('stamps', uttids.numpy())

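        # one EODM step (the P_Ngram kernel built above presumably scores predicted phone n-grams against the prior py), then a supervised step on the fixed labelled batch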
        loss_EODM, loss_fs = train_step(x, stamps, py, G, compute_p_ngram, optimizer, args.lambda_fs)
        # loss_EODM = loss_fs = 0
        loss_supervise = train_G_supervised(supervise_x, supervise_aligns, G, optimizer, args.dim_output, args.lambda_supervision)

        num_processed += len(x)
        progress = num_processed / args.data.train_size
        if step % 10 == 0:
            print('EODM loss: {:.2f}\tloss_fs: {:.3f} * {}\tloss_supervise: {:.3f} * {}\tbatch: {} time: {:.2f} s {:.3f}% step: {}'.format(
                   loss_EODM, loss_fs, args.lambda_fs, loss_supervise, args.lambda_supervision, x.shape, time()-start, progress*100.0, step))
            with writer.as_default():
                tf.summary.scalar("costs/loss_EODM", loss_EODM, step=step)
                tf.summary.scalar("costs/loss_fs", loss_fs, step=step)
                tf.summary.scalar("costs/loss_supervise", loss_supervise, step=step)

        if step % args.dev_step == 0:
            fer, cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, G)
            with writer.as_default():
                tf.summary.scalar("performance/fer", fer, step=step)
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            monitor(dataset_dev[0], G)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format((datetime.now()-start_time).total_seconds()/3600))
Example no. 14
def train():
    with tf.device("/cpu:0"):
        dataset_train = ASR_align_DataSet(
            trans_file=args.dirs.train.trans,
            align_file=args.dirs.train.align,
            uttid2wav=args.dirs.train.wav_scp,
            feat_len_file=args.dirs.train.feat_len,
            args=args,
            _shuffle=True,
            transform=True)
        dataset_dev = ASR_align_DataSet(trans_file=args.dirs.dev.trans,
                                        align_file=args.dirs.dev.align,
                                        uttid2wav=args.dirs.dev.wav_scp,
                                        feat_len_file=args.dirs.dev.feat_len,
                                        args=args,
                                        _shuffle=False,
                                        transform=True)
        dataset_train_supervise = ASR_align_DataSet(
            trans_file=args.dirs.train_supervise.trans,
            align_file=args.dirs.train_supervise.align,
            uttid2wav=args.dirs.train_supervise.wav_scp,
            feat_len_file=args.dirs.train_supervise.feat_len,
            args=args,
            _shuffle=False,
            transform=True)
        feature_train_supervise = TFData(
            dataset=dataset_train_supervise,
            dir_save=args.dirs.train_supervise.tfdata,
            args=args).read()
        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                             dir_save=args.dirs.dev.tfdata,
                             args=args).read()
        supervise_uttids, supervise_x = next(iter(feature_train_supervise.take(args.num_supervised).\
            padded_batch(args.num_supervised, ((), [None, args.dim_input]))))
        supervise_aligns = dataset_train_supervise.get_attrs(
            'align', supervise_uttids.numpy())
        supervise_bounds = dataset_train_supervise.get_attrs(
            'bounds', supervise_uttids.numpy())

        iter_feature_train = iter(
            feature_train.cache().repeat().shuffle(500).padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        feature_dev = feature_dev.padded_batch(args.batch_size,
                                               ((), [None, args.dim_input]))

        dataset_text = TextDataSet(list_files=[args.dirs.lm.data],
                                   args=args,
                                   _shuffle=True)
        tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                      (tf.TensorShape([None])))
        iter_text = iter(tfdata_train.cache().repeat().shuffle(1000).map(
            lambda x: x[:args.model.D.max_label_len]).padded_batch(
                args.batch_size,
                ([args.model.D.max_label_len])).prefetch(buffer_size=5))

    # create model parameters
    G = PhoneClassifier(args)
    D = PhoneDiscriminator3(args)
    G.summary()
    D.summary()

    optimizer_G = tf.keras.optimizers.Adam(args.opti.G.lr,
                                           beta_1=0.5,
                                           beta_2=0.9)
    optimizer_D = tf.keras.optimizers.Adam(args.opti.D.lr,
                                           beta_1=0.5,
                                           beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(G=G, optimizer_G=optimizer_G)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    step = 0

    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        _ckpt_manager = tf.train.CheckpointManager(ckpt,
                                                   args.dirs.checkpoint,
                                                   max_to_keep=1)
        ckpt.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint {} restored!!'.format(
            _ckpt_manager.latest_checkpoint))
        step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    progress = 0

    while step < 99999999:
        start = time()

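        # several discriminator updates per generator update; judging by lambda_gp and the returned gp term, train_D applies a WGAN-style gradient penalty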
        for _ in range(args.opti.D_G_rate):
            uttids, x = next(iter_feature_train)
            stamps = dataset_train.get_attrs('stamps', uttids.numpy())
            text = next(iter_text)
            P_Real = tf.one_hot(text, args.dim_output)
            cost_D, gp = train_D(x, stamps, P_Real, text > 0, G, D,
                                 optimizer_D, args.lambda_gp,
                                 args.model.D.max_label_len)
            # cost_D, gp = train_D(x, P_Real, text>0, G, D, optimizer_D,
            #                      args.lambda_gp, args.model.G.len_seq, args.model.D.max_label_len)

        uttids, x = next(iter_feature_train)
        stamps = dataset_train.get_attrs('stamps', uttids.numpy())
        cost_G, fs = train_G(x, stamps, G, D, optimizer_G, args.lambda_fs)
        # cost_G, fs = train_G(x, G, D, optimizer_G,
        #                      args.lambda_fs, args.model.G.len_seq, args.model.D.max_label_len)

        loss_supervise = train_G_supervised(supervise_x, supervise_aligns, G,
                                            optimizer_G, args.dim_output,
                                            args.lambda_supervision)
        # loss_supervise, bounds_loss = train_G_bounds_supervised(
        #     supervise_x, supervise_bounds, supervise_aligns, G, optimizer_G, args.dim_output)

        num_processed += len(x)
        progress = num_processed / args.data.train_size
        if step % 10 == 0:
            print(
                'cost_G: {:.3f}|{:.3f}\tcost_D: {:.3f}|{:.3f}\tloss_supervise: {:.3f}\tbatch: {}|{}\tused: {:.3f}\t {:.3f}% iter: {}'
                .format(cost_G, fs, cost_D, gp, loss_supervise, x.shape,
                        text.shape,
                        time() - start, progress * 100.0, step))
            with writer.as_default():
                tf.summary.scalar("costs/cost_G", cost_G, step=step)
                tf.summary.scalar("costs/cost_D", cost_D, step=step)
                tf.summary.scalar("costs/gp", gp, step=step)
                tf.summary.scalar("costs/fs", fs, step=step)
                tf.summary.scalar("costs/loss_supervise",
                                  loss_supervise,
                                  step=step)
        if step % args.dev_step == 0:
            # fer, cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, G)
            fer, cer_0 = evaluate(feature_dev,
                                  dataset_dev,
                                  args.data.dev_size,
                                  G,
                                  beam_size=0,
                                  with_stamp=True)
            fer, cer = evaluate(feature_dev,
                                dataset_dev,
                                args.data.dev_size,
                                G,
                                beam_size=0,
                                with_stamp=False)
            with writer.as_default():
                tf.summary.scalar("performance/fer", fer, step=step)
                tf.summary.scalar("performance/cer_0", cer_0, step=step)
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            monitor(dataset_dev[0], G)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example no. 15
def train():
    with tf.device("/cpu:0"):
        dataset_train = ASR_align_DataSet(
            trans_file=args.dirs.train.trans,
            align_file=args.dirs.train.align,
            uttid2wav=args.dirs.train.wav_scp,
            feat_len_file=args.dirs.train.feat_len,
            args=args,
            _shuffle=True,
            transform=True)
        dataset_dev = ASR_align_DataSet(trans_file=args.dirs.dev.trans,
                                        align_file=args.dirs.dev.align,
                                        uttid2wav=args.dirs.dev.wav_scp,
                                        feat_len_file=args.dirs.dev.feat_len,
                                        args=args,
                                        _shuffle=False,
                                        transform=True)
        dataset_train_supervise = ASR_align_DataSet(
            trans_file=args.dirs.train_supervise.trans,
            align_file=args.dirs.train_supervise.align,
            uttid2wav=args.dirs.train_supervise.wav_scp,
            feat_len_file=args.dirs.train_supervise.feat_len,
            args=args,
            _shuffle=False,
            transform=True)
        feature_train_supervise = TFData(
            dataset=dataset_train_supervise,
            dir_save=args.dirs.train_supervise.tfdata,
            args=args).read()
        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                             dir_save=args.dirs.dev.tfdata,
                             args=args).read()

        iter_feature_supervise = iter(
            feature_train_supervise.cache().repeat().padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        iter_feature_train = iter(
            feature_train.cache().repeat().shuffle(500).padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        feature_dev = feature_dev.padded_batch(args.batch_size,
                                               ((), [None, args.dim_input]))

        dataset_text = TextDataSet(list_files=[args.dirs.lm.data],
                                   args=args,
                                   _shuffle=True)
        tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                      (tf.TensorShape([None])))
        iter_text = iter(tfdata_train.cache().repeat().shuffle(1000).map(
            lambda x: x[:args.model.D.max_label_len]).padded_batch(
                args.batch_size,
                ([args.model.D.max_label_len])).prefetch(buffer_size=5))

    # create model parameters
    G = PhoneClassifier(args)
    D = PhoneDiscriminator3(args)
    G.summary()
    D.summary()

    optimizer_G = tf.keras.optimizers.Adam(args.opti.G.lr,
                                           beta_1=0.1,
                                           beta_2=0.5)
    optimizer_D = tf.keras.optimizers.Adam(args.opti.D.lr,
                                           beta_1=0.5,
                                           beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(G=G, optimizer_G=optimizer_G)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    step = 0
    best = 999
    phrase_ctc = False
    step_retention = 0
    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        ckpt.restore(args.dirs.checkpoint)
        print('checkpoint {} restored!!'.format(args.dirs.checkpoint))
        step = int(args.dirs.checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    progress = 0

    while step < 99999999:
        start = time()

        if phrase_ctc:
            # CTC phase
            uttids, x = next(iter_feature_supervise)
            trans = dataset_train_supervise.get_attrs('trans', uttids.numpy())
            loss_G_ctc = train_CTC_G(x, trans, G, D, optimizer_G)
        else:
            # GAN phase
            ## D sub-phase
            for _ in range(args.opti.D_G_rate):
                uttids, x = next(iter_feature_train)
                text = next(iter_text)
                P_Real = tf.one_hot(text, args.dim_output)
                loss_D, loss_D_fake, loss_D_real, gp = train_D(
                    x, P_Real, text > 0, G, D, optimizer_D, args.lambda_gp,
                    args.model.D.max_label_len)
            ## G sub-phase
            uttids, x = next(iter_feature_train)
            _uttids, _x = next(iter_feature_supervise)
            _trans = dataset_train_supervise.get_attrs('trans',
                                                       _uttids.numpy())
            train_GAN_G(x, _x, _trans, G, D, optimizer_G,
                        args.model.D.max_label_len, args.lambda_supervise)

        num_processed += len(x)
        progress = num_processed / args.data.train_size
        if step % 10 == 0:
            if phrase_ctc:
                print(
                    'loss_G_ctc: {:.3f}\tbatch: {}\tused: {:.3f}\t {:.3f}% iter: {}'
                    .format(loss_G_ctc, x.shape,
                            time() - start, progress * 100.0, step))
                with writer.as_default():
                    tf.summary.scalar("costs/loss_G_supervise",
                                      loss_G_ctc,
                                      step=step)
            else:
                print(
                    'loss_GAN: {:.3f}|{:.3f}|{:.3f}\tbatch: {}|{}\tused: {:.3f}\t {:.3f}% iter: {}'
                    .format(loss_D_fake, loss_D_real, gp, x.shape, text.shape,
                            time() - start, progress * 100.0, step))
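        # track the best dev CER; if it has not improved within 2000 steps, restore the best weights and switch between the GAN and CTC phases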
        if step % args.dev_step == 0:
            cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, G)
            if cer < best:
                best = cer
                G_values = save_varibales(G)
            elif step_retention < 2000:
                pass
            else:
                phrase_ctc = not phrase_ctc
                step_retention = 0
                load_values(G, G_values)
                print('====== switching phase ======')
            with writer.as_default():
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            monitor(dataset_dev[0], G)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1
        step_retention += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example no. 16
def train():
    dataset_dev = ASR_align_DataSet(
        file=[args.dirs.dev.data],
        args=args,
        _shuffle=False,
        transform=True)
    with tf.device("/cpu:0"):
        # wav data
        tfdata_train = TFData(dataset=None,
                        dataAttr=['feature', 'label', 'align'],
                        dir_save=args.dirs.train.tfdata,
                        args=args).read(_shuffle=False)
        tfdata_dev = TFData(dataset=None,
                        dataAttr=['feature', 'label', 'align'],
                        dir_save=args.dirs.dev.tfdata,
                        args=args).read(_shuffle=False)

        x_0, y_0, _ = next(iter(tfdata_train.take(args.num_supervised).map(lambda x, y, z: (x, y, z[:args.max_seq_len])).\
            padded_batch(args.num_supervised, ([None, args.dim_input], [None], [None]))))
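        # x_0 / y_0: one fixed labelled batch reused for the supervised loss at every step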
        iter_train = iter(tfdata_train.cache().repeat().shuffle(3000).map(lambda x, y, z: (x, y, z[:args.max_seq_len])).\
            padded_batch(args.batch_size, ([None, args.dim_input], [None], [args.max_seq_len])).prefetch(buffer_size=3))
        tfdata_dev = tfdata_dev.padded_batch(args.batch_size, ([None, args.dim_input], [None], [None]))

        # text data
        dataset_text = TextDataSet(
            list_files=[args.dirs.lm.data],
            args=args,
            _shuffle=True)

        tfdata_train_text = tf.data.Dataset.from_generator(
            dataset_text, (tf.int32), (tf.TensorShape([None])))
        iter_text = iter(tfdata_train_text.cache().repeat().shuffle(100).map(lambda x: x[:args.max_seq_len]).padded_batch(args.batch_size,
            ([args.max_seq_len])).prefetch(buffer_size=5))

    # create model parameters
    G = PhoneClassifier(args)
    D = PhoneDiscriminator2(args)
    G.summary()
    D.summary()
    optimizer_G = tf.keras.optimizers.Adam(args.opti.G.lr, beta_1=0.5, beta_2=0.9)
    optimizer_D = tf.keras.optimizers.Adam(args.opti.D.lr, beta_1=0.5, beta_2=0.9)
    optimizer = tf.keras.optimizers.Adam(args.opti.G.lr, beta_1=0.5, beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(G=G, optimizer_G = optimizer_G)
    ckpt_manager = tf.train.CheckpointManager(ckpt, args.dir_checkpoint, max_to_keep=5)
    step = 0

    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        _ckpt_manager = tf.train.CheckpointManager(ckpt, args.dirs.checkpoint, max_to_keep=1)
        ckpt.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint {} restored!!'.format(_ckpt_manager.latest_checkpoint))
        step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    progress = 0

    while step < 99999999:
        start = time()

        for _ in range(args.opti.D_G_rate):
            x, _, aligns = next(iter_train)
            text = next(iter_text)
            P_Real = tf.one_hot(text, args.dim_output)
            cost_D, gp = train_D(x, aligns, P_Real, text>0, G, D, optimizer_D, args.lambda_gp)

        x, _, aligns = next(iter_train)
        cost_G, fs = train_G(x, aligns, G, D, optimizer_G, args.lambda_fs)
        loss_supervise = train_G_supervised(x_0, y_0, G, optimizer_G, args.dim_output)

        num_processed += len(x)
        if step % 10 == 0:
            print('cost_G: {:.3f}|{:.3f}\tcost_D: {:.3f}|{:.3f}\tloss_supervise: {:.3f}\tbatch: {}|{}\tused: {:.3f}\t {:.3f}% iter: {}'.format(
                   cost_G, fs, cost_D, gp, loss_supervise, x.shape, text.shape, time()-start, progress*100.0, step))
            with writer.as_default():
                tf.summary.scalar("costs/cost_G", cost_G, step=step)
                tf.summary.scalar("costs/cost_D", cost_D, step=step)
                tf.summary.scalar("costs/gp", gp, step=step)
                tf.summary.scalar("costs/fs", fs, step=step)
                tf.summary.scalar("costs/loss_supervise", loss_supervise, step=step)
        if step % args.dev_step == 0:
            fer, cer = evaluation(tfdata_dev, args.data.dev_size, G)
            with writer.as_default():
                tf.summary.scalar("performance/fer", fer, step=step)
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            decode(dataset_dev[0], G)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format((datetime.now()-start_time).total_seconds()/3600))
Example No. 17
0
def train(Model):
    # create dataset and dataloader
    with tf.device("/cpu:0"):
        tfdata_dev = TFData(dataset=None,
                            dataAttr=['feature', 'label', 'align'],
                            dir_save=args.dirs.dev.tfdata,
                            args=args).read(_shuffle=False)
        tfdata_monitor = TFData(dataset=None,
                                dataAttr=['feature', 'label', 'align'],
                                dir_save=args.dirs.train.tfdata,
                                args=args).read(_shuffle=False)
        tfdata_monitor = tfdata_monitor.cache().repeat().shuffle(
            500).padded_batch(args.batch_size,
                              ([None, args.dim_input
                                ], [None], [None])).prefetch(buffer_size=5)
        tfdata_iter = iter(tfdata_monitor)
        tfdata_dev = tfdata_dev.padded_batch(
            args.batch_size, ([None, args.dim_input], [None], [None]))

    # get dataset ngram
    ngram_py, total_num = read_ngram(args.data.k,
                                     args.dirs.ngram,
                                     args.token2idx,
                                     type='list')

    # create model parameters
    opti = tf.keras.optimizers.SGD(0.5)
    model = Model(args, optimizer=opti, name='fc')
    model.summary()

    # save & reload
    ckpt = tf.train.Checkpoint(model=model, optimizer=opti)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    if args.dirs.restore:
        latest_checkpoint = tf.train.CheckpointManager(
            ckpt, args.dirs.restore, max_to_keep=1).latest_checkpoint
        ckpt.restore(latest_checkpoint)
        print('{} restored!!'.format(latest_checkpoint))

    best_rewards = -999
    start_time = datetime.now()
    fer = 1.0
    seed = 999
    step = 0

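    # Warm-up: re-seed and re-initialize the model until the dev FER drops below 0.69,
    # giving up on a seed once FER exceeds 0.76 or 69 steps have passed; every reset also
    # resamples the n-gram statistics into a fresh kernel.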
    while 1:
        if fer < 0.69:
            break
        elif fer > 0.76 or step > 69:
            print('{}-th reset, pre FER: {:.3f}'.format(seed, fer))
            seed += 1
            step = 0
            tf.random.set_seed(seed)
            model = Model(args, optimizer=opti, name='fc')
            head_tail_constrain(next(tfdata_iter), model, opti)
            fer = mini_eva(tfdata_dev, model)
            ngram_sampled = sample(ngram_py, args.data.top_k)
            kernel, py = ngram2kernel(ngram_sampled, args)
        else:
            step += 1
            loss = train_step(model, tfdata_iter, kernel, py)
            fer = mini_eva(tfdata_dev, model)
            print('\tloss: {:.3f}\tFER: {:.3f}'.format(loss, fer))

    for global_step in range(99999):
        run_model_time = time()
        loss = train_step(model, tfdata_iter, kernel, py)

        used_time = time() - run_model_time
        if global_step % 1 == 0:
            print('full training loss: {:.3f}, spend {:.2f}s step {}'.format(
                loss, used_time, global_step))

        if global_step % args.dev_step == 0:
            evaluation(tfdata_dev, model)
        if global_step % args.decode_step == 0:
            decode(model)
        if global_step % args.fs_step == 0:
            fs_constrain(next(tfdata_iter), model, opti)
        if global_step % args.save_step == 0:
            ckpt_manager.save()

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example No. 18
0
def Train():
    with tf.device("/cpu:0"):
        # dataset_train = ASR_align_ArkDataSet(
        #     scp_file=args.dirs.train.scp,
        #     trans_file=args.dirs.train.trans,
        #     align_file=None,
        #     feat_len_file=None,
        #     args=args,
        #     _shuffle=False,
        #     transform=False)
        # dataset_dev = ASR_align_ArkDataSet(
        #     scp_file=args.dirs.dev.scp,
        #     trans_file=args.dirs.dev.trans,
        #     align_file=None,
        #     feat_len_file=None,
        #     args=args,
        #     _shuffle=False,
        #     transform=False)
        dataset_train = ASR_align_DataSet(
            trans_file=args.dirs.train.trans,
            align_file=args.dirs.train.align,
            uttid2wav=args.dirs.train.wav_scp,
            feat_len_file=args.dirs.train.feat_len,
            args=args,
            _shuffle=False,
            transform=True)
        dataset_dev = ASR_align_DataSet(trans_file=args.dirs.dev.trans,
                                        uttid2wav=args.dirs.dev.wav_scp,
                                        align_file=args.dirs.dev.align,
                                        feat_len_file=args.dirs.dev.feat_len,
                                        args=args,
                                        _shuffle=False,
                                        transform=True)
        # wav data
        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read(transform=True)
        feature_dev = TFData(dataset=dataset_dev,
                             dir_save=args.dirs.dev.tfdata,
                             args=args).read(transform=True)
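        # Bucket utterances by length so each batch pads to a similar sequence length;
        # boundaries and per-bucket batch sizes come from the config.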
        bucket = tf.data.experimental.bucket_by_sequence_length(
            element_length_func=lambda uttid, x: tf.shape(x)[0],
            bucket_boundaries=args.list_bucket_boundaries,
            bucket_batch_sizes=args.list_batch_size,
            padded_shapes=((), [None, args.dim_input]))
        iter_feature_train = iter(
            feature_train.repeat().shuffle(10).apply(bucket).prefetch(
                buffer_size=5))
        # iter_feature_train = iter(feature_train.repeat().shuffle(500).padded_batch(args.batch_size,
        #         ((), [None, args.dim_input])).prefetch(buffer_size=5))
        # feature_dev = feature_dev.apply(bucket).prefetch(buffer_size=5)
        feature_dev = feature_dev.padded_batch(args.batch_size,
                                               ((), [None, args.dim_input]))

    # create model parameters
    model = Model(args)
    model.summary()
    # lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    #     args.opti.lr,
    #     decay_steps=args.opti.decay_steps,
    #     decay_rate=0.5,
    #     staircase=True)
    optimizer = tf.keras.optimizers.Adam(args.opti.lr, beta_1=0.5, beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    step = 0

    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        _ckpt_manager = tf.train.CheckpointManager(ckpt,
                                                   args.dirs.checkpoint,
                                                   max_to_keep=1)
        ckpt.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint {} restored!!'.format(
            _ckpt_manager.latest_checkpoint))
        step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    while step < 99999999:
        start = time()

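        # Draw a bucketed batch of features, look up its transcripts by utterance id,
        # and take one supervised CTC step.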
        uttids, x = next(iter_feature_train)
        trans = dataset_train.get_attrs('trans', uttids.numpy())
        loss_supervise = train_CTC_supervised(x, trans, model, optimizer)

        num_processed += len(x)
        progress = num_processed / args.data.train_size

        if step % 10 == 0:
            print('loss: {:.3f}\tbatch: {}\tused: {:.3f}\t {:.3f}% step: {}'.
                  format(loss_supervise, x.shape,
                         time() - start, progress * 100, step))
            with writer.as_default():
                tf.summary.scalar("costs/loss_supervise",
                                  loss_supervise,
                                  step=step)
        if step % args.dev_step == 0:
            cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, model)
            with writer.as_default():
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            monitor(dataset_dev[0], model)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example No. 19
0
def train():
    with tf.device("/cpu:0"):
        dataset_train = ASR_align_DataSet(
            trans_file=args.dirs.train.trans,
            align_file=None,
            uttid2wav=args.dirs.train.wav_scp,
            feat_len_file=args.dirs.train.feat_len,
            args=args,
            _shuffle=True,
            transform=True)
        dataset_dev = ASR_align_DataSet(
            trans_file=args.dirs.dev.trans,
            align_file=None,
            uttid2wav=args.dirs.dev.wav_scp,
            feat_len_file=args.dirs.dev.feat_len,
            args=args,
            _shuffle=False,
            transform=True)
        feature_train = TFData(dataset=dataset_train,
                        dir_save=args.dirs.train.tfdata,
                        args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                        dir_save=args.dirs.dev.tfdata,
                        args=args).read()
        iter_feature_train = iter(feature_train.cache().repeat().shuffle(500).padded_batch(args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        feature_dev = feature_dev.padded_batch(args.batch_size, ((), [None, args.dim_input]))

    # create model parameters
    assigner = attentionAssign(args)
    G = PhoneClassifier(args, dim_input=args.model.attention.num_hidden)
    assigner.summary()
    G.summary()

    optimizer_G = tf.keras.optimizers.Adam(args.opti.G.lr, beta_1=0.9, beta_2=0.95)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(G=G, optimizer_G=optimizer_G)
    ckpt_manager = tf.train.CheckpointManager(ckpt, args.dir_checkpoint, max_to_keep=20)
    step = 0

    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        _ckpt_manager = tf.train.CheckpointManager(ckpt, args.dirs.checkpoint, max_to_keep=1)
        ckpt.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint {} restored!!'.format(_ckpt_manager.latest_checkpoint))
        step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    progress = 0

    while step < 99999999:
        start = time()

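        # Supervised step through the attention assigner + phone classifier; it returns a
        # cross-entropy, a quantity, and a CTC loss term.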
        uttids, x = next(iter_feature_train)
        y = dataset_train.get_attrs('trans', uttids.numpy())
        ce_loss_supervise, quantity_loss_supervise, _ctc_loss = train_G_supervised(
            x, y, assigner, G, optimizer_G, args.dim_output, args.lambda_supervision)

        num_processed += len(x)
        progress = num_processed / args.data.train_size
        if step % 10 == 0:
            print('loss_supervise: {:.3f}|{:.3f}|{:.3f}\tbatch: {}\tused: {:.3f}\t {:.3f}% iter: {}'.format(
                   ce_loss_supervise, quantity_loss_supervise, _ctc_loss, x.shape, time()-start, progress*100.0, step))
            # with writer.as_default():
            #     tf.summary.scalar("costs/cost_G", cost_G, step=step)
            #     tf.summary.scalar("costs/cost_D", cost_D, step=step)
            #     tf.summary.scalar("costs/gp", gp, step=step)
            #     tf.summary.scalar("costs/loss_supervise", loss_supervise, step=step)
        if step % args.dev_step == 0:
            cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, assigner, G)
            with writer.as_default():
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            monitor(dataset_dev[0], assigner, G)

        step += 1

    print('training duration: {:.2f}h'.format((datetime.now()-start_time).total_seconds()/3600))
Example No. 20
0
def train_mul(Model):

    # create dataset and dataloader
    with tf.device("/cpu:0"):
        tfdata_dev = TFData(dataset=None,
                            dataAttr=['feature', 'label', 'align'],
                            dir_save=args.dirs.dev.tfdata,
                            args=args).read(_shuffle=False)
        tfdata_monitor = TFData(dataset=None,
                                dataAttr=['feature', 'label', 'align'],
                                dir_save=args.dirs.train.tfdata,
                                args=args).read(_shuffle=False)
        tfdata_monitor = tfdata_monitor.repeat().shuffle(500).padded_batch(
            args.batch_size,
            ([None, args.dim_input], [None], [None])).prefetch(buffer_size=5)
        tfdata_iter = iter(tfdata_monitor)
        tfdata_dev = tfdata_dev.padded_batch(
            args.batch_size, ([None, args.dim_input], [None], [None]))

    # get dataset ngram
    ngram_py, total_num = read_ngram(args.data.k,
                                     args.dirs.ngram,
                                     args.token2idx,
                                     type='list')

    # create model parameters
    opti = tf.keras.optimizers.SGD(0.5)
    model = Model(args, optimizer=opti, name='fc')
    model.summary()

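    # Each worker thread pins a copy of the model to one GPU, pulls (id, weights, batch)
    # items from `queue_input`, evaluates the EODM statistics and pushes them back through
    # `queue_output`; the two queues are assumed to be created at module level.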
    def thread_session(thread_id, queue_input, queue_output):
        global kernel
        gpu = args.list_gpus[thread_id]
        with tf.device(gpu):
            opti_adam = build_optimizer(args, type='adam')
            model = Model(args,
                          optimizer=opti_adam,
                          name='fc' + str(thread_id))
            print('thread_{} is waiting to run on {}....'.format(
                thread_id, gpu))
            while True:
                # s = time()
                id, weights, x, aligns_sampled = queue_input.get()
                model.set_weights(weights)
                # t = time()
                logits = model(x, training=False)
                pz, K = model.EODM(logits, aligns_sampled, kernel)
                queue_output.put((id, pz, K))
                # print('{} {:.3f}|{:.3f}s'.format(gpu, t-s, time()-s))

    for id in range(args.num_gpus):
        thread = threading.Thread(target=thread_session,
                                  args=(id, queue_input, queue_output))
        thread.daemon = True
        thread.start()

    # save & reload
    ckpt = tf.train.Checkpoint(model=model, optimizer=opti)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    if args.dirs.restore:
        latest_checkpoint = tf.train.CheckpointManager(
            ckpt, args.dirs.restore, max_to_keep=1).latest_checkpoint
        ckpt.restore(latest_checkpoint)
        print('{} restored!!'.format(latest_checkpoint))

    best_rewards = -999
    start_time = datetime.now()
    fer = 1.0
    seed = 99999
    step = 0
    global aligns_sampled, kernel

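    # Same warm-up/reset scheme as in the single-GPU variant above: re-seed and
    # re-initialize until the dev FER is low enough before full training starts.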
    while 1:
        if fer < 0.69:
            break
        elif fer > 0.77 or step > 69:
            print('{}-th reset, pre FER: {:.3f}'.format(seed, fer))
            seed += 1
            step = 0
            tf.random.set_seed(seed)
            model = Model(args, optimizer=opti, name='fc')
            head_tail_constrain(next(tfdata_iter), model, opti)
            fer = mini_eva(tfdata_dev, model)
            ngram_sampled = sample(ngram_py, args.data.top_k)
            kernel, py = ngram2kernel(ngram_sampled, args)
        else:
            step += 1
            loss = train_step(model, tfdata_iter, py)
            fer = mini_eva(tfdata_dev, model)
            print('\tloss: {:.3f}\tFER: {:.3f}'.format(loss, fer))

    for global_step in range(99999):
        run_model_time = time()
        loss = train_step(model, tfdata_iter, py)

        used_time = time() - run_model_time
        if global_step % 1 == 0:
            print('full training loss: {:.3f}, spend {:.2f}s step {}'.format(
                loss, used_time, global_step))

        if global_step % args.dev_step == 0:
            evaluation(tfdata_dev, model)
        if global_step % args.decode_step == 0:
            decode(model)
        if global_step % args.fs_step == 0:
            fs_constrain(next(tfdata_iter), model, opti)
        if global_step % args.save_step == 0:
            ckpt_manager.save()

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example No. 21
0
def Train():
    with tf.device("/cpu:0"):
        dataset_train = ASR_align_ArkDataSet(scp_file=args.dirs.train.scp,
                                             trans_file=args.dirs.train.trans,
                                             align_file=None,
                                             feat_len_file=None,
                                             args=args,
                                             _shuffle=False,
                                             transform=False)
        dataset_dev = ASR_align_ArkDataSet(scp_file=args.dirs.dev.scp,
                                           trans_file=args.dirs.dev.trans,
                                           align_file=None,
                                           feat_len_file=None,
                                           args=args,
                                           _shuffle=False,
                                           transform=False)
        # wav data
        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                             dir_save=args.dirs.dev.tfdata,
                             args=args).read()
        bucket = tf.data.experimental.bucket_by_sequence_length(
            element_length_func=lambda uttid, x: tf.shape(x)[0],
            bucket_boundaries=args.list_bucket_boundaries,
            bucket_batch_sizes=args.list_batch_size,
            padded_shapes=((), [None, args.dim_input]))
        iter_feature_train = iter(
            feature_train.repeat().shuffle(500).apply(bucket).prefetch(
                buffer_size=5))
        # iter_feature_train = iter(feature_train.repeat().shuffle(500).padded_batch(args.batch_size,
        #         ((), [None, args.dim_input])).prefetch(buffer_size=5))
        feature_dev = feature_dev.padded_batch(args.batch_size,
                                               ((), [None, args.dim_input]))

    strategy = tf.distribute.MirroredStrategy(
        devices=["device:GPU:0", "device:GPU:1"])
    with strategy.scope():
        # create model parameters
        model = conv_lstm(args)
        model.summary()
        optimizer = tf.keras.optimizers.Adam(args.opti.lr,
                                             beta_1=0.5,
                                             beta_2=0.9)

        writer = tf.summary.create_file_writer(str(args.dir_log))
        ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
        ckpt_manager = tf.train.CheckpointManager(ckpt,
                                                  args.dir_checkpoint,
                                                  max_to_keep=20)
        step = 0

        # if a checkpoint exists, restore the latest checkpoint.
        if args.dirs.checkpoint:
            _ckpt_manager = tf.train.CheckpointManager(ckpt,
                                                       args.dirs.checkpoint,
                                                       max_to_keep=1)
            ckpt.restore(_ckpt_manager.latest_checkpoint)
            print('checkpoint {} restored!!'.format(
                _ckpt_manager.latest_checkpoint))
            step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    # Distributed training step: runs 1000 replicated supervised CTC updates per call.
    @tf.function(experimental_relax_shapes=True)
    def _train():
        for _ in tf.range(1000):
            strategy.experimental_run_v2(train_CTC_supervised,
                                         args=(next(iter_feature_train),
                                               dataset_train, model,
                                               optimizer))

    while step < 99999999:
        start = time()
        _train()
        step += 1000  # one call to _train() covers 1000 inner steps
        # loss_supervise = tf.reduce_mean(res._values)
        # loss_supervise = train_CTC_supervised(x, trans, model, optimizer)

        # num_processed += len(x)
        # progress = num_processed / args.data.train_size

        # if step % 10 == 0:
        #     print('loss: {:.3f}\tused: {:.3f}\t step: {}'.format(
        #            loss_supervise, time()-start, step))
        #     with writer.as_default():
        #         tf.summary.scalar("costs/loss_supervise", loss_supervise, step=step)
        # if step % args.dev_step == 0:
        #     cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, model)
        #     with writer.as_default():
        #         tf.summary.scalar("performance/cer", cer, step=step)
        # if step % args.decode_step == 0:
        #     monitor(dataset_dev[0], model)
        # if step % args.save_step == 0:
        #     save_path = ckpt_manager.save(step)
        #     print('save model {}'.format(save_path))
        #
        # step += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example No. 22
0
def train(Model):
    # load external LM
    with tf.device("/cpu:0"):
        dataset_dev = ASR_align_DataSet(
            file=[args.dirs.dev.data],
            args=args,
            _shuffle=False,
            transform=True)
        tfdata_train = TFData(dataset=None,
                        dataAttr=['feature', 'label', 'align'],
                        dir_save=args.dirs.train.tfdata,
                        args=args).read(_shuffle=False)
        tfdata_dev = TFData(dataset=None,
                        dataAttr=['feature', 'label', 'align'],
                        dir_save=args.dirs.dev.tfdata,
                        args=args).read(_shuffle=False)

        x_0, y_0, aligns_0 = next(iter(tfdata_train.take(args.num_supervised).\
            padded_batch(args.num_supervised, ([None, args.dim_input], [None], [None]))))
        iter_train = iter(tfdata_train.cache().repeat().shuffle(3000).\
            padded_batch(args.batch_size, ([None, args.dim_input], [None], [None])).prefetch(buffer_size=3))
        tfdata_dev = tfdata_dev.padded_batch(args.batch_size, ([None, args.dim_input], [None], [None]))

    # get dataset ngram
    ngram_py, total_num = read_ngram(args.data.k, args.dirs.ngram, args.token2idx, type='list')
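    # Turn the n-gram statistics into the convolution kernel and prior `py` consumed by the
    # EODM loss in `train_step` below.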
    kernel, py = ngram2kernel(ngram_py, args)

    # create model parameters
    model = Model(args)
    compute_p_ngram = P_Ngram(kernel, args)
    model.summary()
    compute_p_ngram.summary()

    # build optimizer
    if args.opti.type == 'adam':
        optimizer = tf.keras.optimizers.Adam(args.opti.lr, beta_1=0.5, beta_2=0.9)
        # optimizer = tf.keras.optimizers.Adam(args.opti.lr*0.1, beta_1=0.5, beta_2=0.9)
    elif args.opti.type == 'sgd':
        optimizer = tf.keras.optimizers.SGD(lr=args.opti.lr, momentum=0.9, decay=0.98)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(model=model, optimizer = optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, args.dir_checkpoint, max_to_keep=5)
    step = 0

    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        _ckpt_manager = tf.train.CheckpointManager(ckpt, args.dirs.checkpoint, max_to_keep=1)
        ckpt.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint {} restored!!'.format(_ckpt_manager.latest_checkpoint))
        step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    progress = 0

    # step = 1600
    while step < 99999999:
        start = time()

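        # Each step combines the unsupervised EODM and fs losses on an unlabeled batch with a
        # supervised loss on the fixed labeled batch (x_0, y_0) taken above.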
        x, _, aligns = next(iter_train)
        loss_EODM, loss_fs = train_step(x, aligns, py, model, compute_p_ngram, optimizer, args.lambda_fs)
        loss_supervise = train_G_supervised(x_0, y_0, model, optimizer, args.dim_output)

        num_processed += len(x)
        progress = num_processed / args.data.train_size
        if step % 10 == 0:
            print('EODM loss: {:.2f}\tloss_fs: {:.3f} * {}\tloss_supervise: {:.3f} * {}\tbatch: {} time: {:.2f} s {:.3f}% step: {}'.format(
                   loss_EODM, loss_fs, args.lambda_fs, loss_supervise, args.lambda_supervision, x.shape, time()-start, progress*100.0, step))
            with writer.as_default():
                tf.summary.scalar("costs/loss_EODM", loss_EODM, step=step)
                tf.summary.scalar("costs/loss_fs", loss_fs, step=step)
                tf.summary.scalar("costs/loss_supervise", loss_supervise, step=step)

        if step % args.dev_step == 0:
            fer, cer = evaluation(tfdata_dev, args.data.dev_size, model)
            with writer.as_default():
                tf.summary.scalar("performance/fer", fer, step=step)
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            decode(dataset_dev[0], model)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format((datetime.now()-start_time).total_seconds()/3600))
Example No. 23
0
def Train():
    with tf.device("/cpu:0"):
        dataset_train = ASR_align_ArkDataSet(scp_file=args.dirs.train.scp,
                                             trans_file=args.dirs.train.trans,
                                             align_file=None,
                                             feat_len_file=None,
                                             args=args,
                                             _shuffle=False,
                                             transform=True)

        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read(_shuffle=True, transform=True)
        bucket = tf.data.experimental.bucket_by_sequence_length(
            element_length_func=lambda uttid, x: tf.shape(x)[0],
            bucket_boundaries=args.list_bucket_boundaries,
            bucket_batch_sizes=args.list_batch_size,
            padded_shapes=((), [None, args.dim_input]))
        iter_feature_train = iter(
            feature_train.repeat().shuffle(10).apply(bucket).prefetch(
                buffer_size=5))

    # create model parameters
    model = Transformer(args)
    model.summary()

    # learning_rate = CustomSchedule(args.model.G.d_model)
    # lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    #     0.0001,
    #     decay_steps=10000,
    #     decay_rate=0.5,
    #     staircase=True)
    optimizer = tf.keras.optimizers.Adam(0.000005,
                                         beta_1=0.5,
                                         beta_2=0.9,
                                         epsilon=1e-9)
    # optimizer = tf.keras.optimizers.SGD(0.1)

    ckpt = tf.train.Checkpoint(model=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    step = 0

    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        _ckpt_manager = tf.train.CheckpointManager(ckpt,
                                                   args.dirs.checkpoint,
                                                   max_to_keep=1)
        ckpt.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint {} restored!!'.format(
            _ckpt_manager.latest_checkpoint))
        step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    # uttids, x = next(iter_feature_train)
    # trans_sos = dataset_train.get_attrs('trans_sos', uttids.numpy())
    # trans_eos = dataset_train.get_attrs('trans_eos', uttids.numpy())
    while step < 99999999:
        start = time()
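        # Fetch the sos-prefixed and eos-suffixed transcripts for this batch; these presumably
        # serve as decoder input and target for teacher-forced Transformer training.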
        uttids, x = next(iter_feature_train)
        trans_sos = dataset_train.get_attrs('trans_sos', uttids.numpy())
        trans_eos = dataset_train.get_attrs('trans_eos', uttids.numpy())
        loss_supervise = train_step(x, trans_sos, trans_eos, model, optimizer)

        num_processed += len(x)
        progress = num_processed / args.data.train_size

        if step % 20 == 0:
            print('loss: {:.3f}\tbatch: {}\tused: {:.3f}\t {:.3f}% step: {}'.
                  format(loss_supervise, x.shape,
                         time() - start, progress * 100, step))
        # if step % args.save_step == 0:
        #     save_path = ckpt_manager.save(step)
        #     print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))