Example #1
def train():
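    """Semi-supervised GAN training for ASR (TF1 graph mode).

    Builds bucketed TFRecord readers for the supervised and unsupervised speech
    sets plus a text iterator for the discriminator, constructs the generator G
    (ASR model), an inference copy G_infer, the discriminator D and a GAN
    wrapper, optionally restores G, warms it up with 2000 CTC steps, then runs
    generator updates (the discriminator update in this snippet is commented
    out) with periodic logging, checkpointing, dev evaluation and test decoding.
    """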
    print('reading data from ', args.dirs.train.tfdata)
    dataReader_train = TFReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_batch_bucket()
    dataReader_untrain = TFReader(args.dirs.untrain.tfdata, args=args)
    batch_untrain = dataReader_untrain.fentch_batch(args.batch_size)
    # batch_untrain = dataReader_untrain.fentch_batch_bucket()
    args.dirs.untrain.tfdata = Path(args.dirs.untrain.tfdata)
    args.data.untrain_size = TFData.read_tfdata_info(
        args.dirs.untrain.tfdata)['size_dataset']

    dataset_text = TextDataSet(list_files=[args.dirs.text.data],
                               args=args,
                               _shuffle=True)
    tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                  (tf.TensorShape([None])))
    iter_text = tfdata_train.cache().repeat().shuffle(1000).\
        padded_batch(args.text_batch_size, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    feat, label = readTFRecord(args.dirs.dev.tfdata,
                               args,
                               _shuffle=False,
                               transform=True)
    dataloader_dev = ASRDataLoader(args.dataset_dev,
                                   args,
                                   feat,
                                   label,
                                   batch_size=args.batch_size,
                                   num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    G = args.Model(tensor_global_step,
                   encoder=args.model.encoder.type,
                   decoder=args.model.decoder.type,
                   batch=batch_train,
                   training=True,
                   args=args)

    G_infer = args.Model(tensor_global_step,
                         encoder=args.model.encoder.type,
                         decoder=args.model.decoder.type,
                         training=False,
                         args=args)
    vars_G = G.trainable_variables

    D = args.Model_D(tensor_global_step1,
                     training=True,
                     name='discriminator',
                     args=args)

    gan = args.GAN([tensor_global_step0, tensor_global_step1],
                   G,
                   D,
                   batch=batch_train,
                   unbatch=batch_untrain,
                   name='GAN',
                   args=args)

    size_variables()

    start_time = datetime.now()
    saver_G = tf.train.Saver(vars_G, max_to_keep=1)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_G.restore(sess, args.dirs.checkpoint_G)

        for i in range(2000):
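            # warm-up: run the generator's supervised (CTC) training op before entering the GAN loop below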
            ctc_loss, *_ = sess.run(G.list_run)
            if i % 400 == 0:
                print('ctc_loss: {:.2f}'.format(ctc_loss))

        batch_time = time()
        num_processed = 0
        num_processed_unbatch = 0
        progress = 0
        while progress < args.num_epochs:

            global_step, lr_G, lr_D = sess.run([
                tensor_global_step0, gan.learning_rate_G, gan.learning_rate_D
            ])

            # untrain
            text = sess.run(iter_text)
            text_lens = get_batch_length(text)
            shape_text = text.shape
            # loss_D, loss_D_res, loss_D_text, loss_gp, _ = sess.run(gan.list_train_D,
            #                                               feed_dict={gan.list_pl[0]:text,
            #                                                          gan.list_pl[1]:text_lens})
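            # the discriminator update above is commented out in this snippet, so its losses are zeroed for logging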
            loss_D = loss_D_res = loss_D_text = loss_gp = 0
            (loss_G, loss_G_supervise, _), (shape_feature, shape_unfeature) = \
                sess.run([gan.list_train_G, gan.list_feature_shape])

            num_processed += shape_feature[0]
            num_processed_unbatch += shape_unfeature[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size
            progress_unbatch = num_processed_unbatch / args.data.untrain_size

            if global_step % 20 == 0:
                print(
                    'loss_supervise: {:.2f}, loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\tbatch: {}|{}|{}\tlr:{:.1e}|{:.1e} {:.2f}s {:.1f}%|{:.1f}% step: {}'
                    .format(loss_G_supervise, loss_D_res, loss_D_text, loss_gp,
                            shape_feature, shape_unfeature, shape_text, lr_G,
                            lr_D, used_time, progress * 100.0,
                            progress_unbatch * 100.0, global_step))
                # summary.summary_scalar('loss_G', loss_G, global_step)
                # summary.summary_scalar('loss_D', loss_D, global_step)
                # summary.summary_scalar('lr_G', lr_G, global_step)
                # summary.summary_scalar('lr_D', lr_D, global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess),
                           str(args.dir_checkpoint / 'model'),
                           global_step=global_step,
                           write_meta_graph=True)

            # if global_step % args.dev_step == args.dev_step - 1:
            if global_step % args.dev_step == 0:
                cer, wer = dev(step=global_step,
                               dataloader=dataloader_dev,
                               model=G_infer,
                               sess=sess,
                               unit=args.data.unit,
                               idx2token=args.idx2token,
                               token2idx=args.token2idx)
                # summary.summary_scalar('dev_cer', cer, global_step)
                # summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                decode_test(step=global_step,
                            sample=args.dataset_test[10],
                            model=G_infer,
                            sess=sess,
                            unit=args.data.unit,
                            idx2token=args.idx2token,
                            token2idx=args.token2idx)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example #2
def train():
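    """Phase-switching semi-supervised training of a phone classifier (TF2 eager).

    Alternates between a supervised CTC phase (train_CTC_G) and a GAN phase
    (train_D followed by train_GAN_G). Whenever the dev CER improves, the
    generator weights are snapshotted; if the CER fails to improve after at
    least 2000 steps in the current phase, the best weights are restored and
    the phase is switched.
    """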
    with tf.device("/cpu:0"):
        dataset_train = ASR_align_DataSet(
            trans_file=args.dirs.train.trans,
            align_file=args.dirs.train.align,
            uttid2wav=args.dirs.train.wav_scp,
            feat_len_file=args.dirs.train.feat_len,
            args=args,
            _shuffle=True,
            transform=True)
        dataset_dev = ASR_align_DataSet(trans_file=args.dirs.dev.trans,
                                        align_file=args.dirs.dev.align,
                                        uttid2wav=args.dirs.dev.wav_scp,
                                        feat_len_file=args.dirs.dev.feat_len,
                                        args=args,
                                        _shuffle=False,
                                        transform=True)
        dataset_train_supervise = ASR_align_DataSet(
            trans_file=args.dirs.train_supervise.trans,
            align_file=args.dirs.train_supervise.align,
            uttid2wav=args.dirs.train_supervise.wav_scp,
            feat_len_file=args.dirs.train_supervise.feat_len,
            args=args,
            _shuffle=False,
            transform=True)
        feature_train_supervise = TFData(
            dataset=dataset_train_supervise,
            dir_save=args.dirs.train_supervise.tfdata,
            args=args).read()
        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                             dir_save=args.dirs.dev.tfdata,
                             args=args).read()

        iter_feature_supervise = iter(
            feature_train_supervise.cache().repeat().padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        iter_feature_train = iter(
            feature_train.cache().repeat().shuffle(500).padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        feature_dev = feature_dev.padded_batch(args.batch_size,
                                               ((), [None, args.dim_input]))

        dataset_text = TextDataSet(list_files=[args.dirs.lm.data],
                                   args=args,
                                   _shuffle=True)
        tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                      (tf.TensorShape([None])))
        iter_text = iter(tfdata_train.cache().repeat().shuffle(1000).map(
            lambda x: x[:args.model.D.max_label_len]).padded_batch(
                args.batch_size,
                ([args.model.D.max_label_len])).prefetch(buffer_size=5))

    # create model parameters
    G = PhoneClassifier(args)
    D = PhoneDiscriminator3(args)
    G.summary()
    D.summary()

    optimizer_G = tf.keras.optimizers.Adam(args.opti.G.lr,
                                           beta_1=0.1,
                                           beta_2=0.5)
    optimizer_D = tf.keras.optimizers.Adam(args.opti.D.lr,
                                           beta_1=0.5,
                                           beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(G=G, optimizer_G=optimizer_G)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    step = 0
    best = 999
    phase_ctc = False
    step_retention = 0
    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        ckpt.restore(args.dirs.checkpoint)
        print('checkpoint {} restored!!'.format(args.dirs.checkpoint))
        step = int(args.dirs.checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    progress = 0

    while step < 99999999:
        start = time()

        if phase_ctc:
            # CTC phase
            uttids, x = next(iter_feature_supervise)
            trans = dataset_train_supervise.get_attrs('trans', uttids.numpy())
            loss_G_ctc = train_CTC_G(x, trans, G, D, optimizer_G)
        else:
            # GAN phase
            ## D sub-phase
            for _ in range(args.opti.D_G_rate):
                uttids, x = next(iter_feature_train)
                text = next(iter_text)
                P_Real = tf.one_hot(text, args.dim_output)
                loss_D, loss_D_fake, loss_D_real, gp = train_D(
                    x, P_Real, text > 0, G, D, optimizer_D, args.lambda_gp,
                    args.model.D.max_label_len)
            ## G sub-phase
            uttids, x = next(iter_feature_train)
            _uttids, _x = next(iter_feature_supervise)
            _trans = dataset_train_supervise.get_attrs('trans',
                                                       _uttids.numpy())
            train_GAN_G(x, _x, _trans, G, D, optimizer_G,
                        args.model.D.max_label_len, args.lambda_supervise)

        num_processed += len(x)
        progress = num_processed / args.data.train_size
        if step % 10 == 0:
            if phase_ctc:
                print(
                    'loss_G_ctc: {:.3f}\tbatch: {}\tused: {:.3f}\t {:.3f}% iter: {}'
                    .format(loss_G_ctc, x.shape,
                            time() - start, progress * 100.0, step))
                with writer.as_default():
                    tf.summary.scalar("costs/loss_G_supervise",
                                      loss_G_ctc,
                                      step=step)
            else:
                print(
                    'loss_GAN: {:.3f}|{:.3f}|{:.3f}\tbatch: {}|{}\tused: {:.3f}\t {:.3f}% iter: {}'
                    .format(loss_D_fake, loss_D_real, gp, x.shape, text.shape,
                            time() - start, progress * 100.0, step))
        if step % args.dev_step == 0:
            cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, G)
            if cer < best:
                best = cer
                G_values = save_varibales(G)
            elif step_retention < 2000:
                pass
            else:
                phase_ctc = not phase_ctc
                step_retention = 0
                load_values(G, G_values)
                print('====== switching phase ======')
            with writer.as_default():
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            monitor(dataset_dev[0], G)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1
        step_retention += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example #3
def train():
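    """GAN training with a fixed supervised batch (TF2 eager).

    Each step runs args.opti.D_G_rate discriminator updates (train_D, using
    frame stamps), one adversarial generator update (train_G) and one
    supervised update (train_G_supervised) on a fixed batch of
    args.num_supervised aligned utterances. Dev evaluation reports FER and CER
    both with and without frame stamps.
    """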
    with tf.device("/cpu:0"):
        dataset_train = ASR_align_DataSet(
            trans_file=args.dirs.train.trans,
            align_file=args.dirs.train.align,
            uttid2wav=args.dirs.train.wav_scp,
            feat_len_file=args.dirs.train.feat_len,
            args=args,
            _shuffle=True,
            transform=True)
        dataset_dev = ASR_align_DataSet(trans_file=args.dirs.dev.trans,
                                        align_file=args.dirs.dev.align,
                                        uttid2wav=args.dirs.dev.wav_scp,
                                        feat_len_file=args.dirs.dev.feat_len,
                                        args=args,
                                        _shuffle=False,
                                        transform=True)
        dataset_train_supervise = ASR_align_DataSet(
            trans_file=args.dirs.train_supervise.trans,
            align_file=args.dirs.train_supervise.align,
            uttid2wav=args.dirs.train_supervise.wav_scp,
            feat_len_file=args.dirs.train_supervise.feat_len,
            args=args,
            _shuffle=False,
            transform=True)
        feature_train_supervise = TFData(
            dataset=dataset_train_supervise,
            dir_save=args.dirs.train_supervise.tfdata,
            args=args).read()
        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                             dir_save=args.dirs.dev.tfdata,
                             args=args).read()
        supervise_uttids, supervise_x = next(iter(feature_train_supervise.take(args.num_supervised).\
            padded_batch(args.num_supervised, ((), [None, args.dim_input]))))
        supervise_aligns = dataset_train_supervise.get_attrs(
            'align', supervise_uttids.numpy())
        supervise_bounds = dataset_train_supervise.get_attrs(
            'bounds', supervise_uttids.numpy())

        iter_feature_train = iter(
            feature_train.cache().repeat().shuffle(500).padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        feature_dev = feature_dev.padded_batch(args.batch_size,
                                               ((), [None, args.dim_input]))

        dataset_text = TextDataSet(list_files=[args.dirs.lm.data],
                                   args=args,
                                   _shuffle=True)
        tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                      (tf.TensorShape([None])))
        iter_text = iter(tfdata_train.cache().repeat().shuffle(1000).map(
            lambda x: x[:args.model.D.max_label_len]).padded_batch(
                args.batch_size,
                ([args.model.D.max_label_len])).prefetch(buffer_size=5))

    # create model parameters
    G = PhoneClassifier(args)
    D = PhoneDiscriminator3(args)
    G.summary()
    D.summary()

    optimizer_G = tf.keras.optimizers.Adam(args.opti.G.lr,
                                           beta_1=0.5,
                                           beta_2=0.9)
    optimizer_D = tf.keras.optimizers.Adam(args.opti.D.lr,
                                           beta_1=0.5,
                                           beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(G=G, optimizer_G=optimizer_G)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    step = 0

    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        _ckpt_manager = tf.train.CheckpointManager(ckpt,
                                                   args.dirs.checkpoint,
                                                   max_to_keep=1)
        ckpt.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint {} restored!!'.format(
            _ckpt_manager.latest_checkpoint))
        step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    progress = 0

    while step < 99999999:
        start = time()

        for _ in range(args.opti.D_G_rate):
            uttids, x = next(iter_feature_train)
            stamps = dataset_train.get_attrs('stamps', uttids.numpy())
            text = next(iter_text)
            P_Real = tf.one_hot(text, args.dim_output)
            cost_D, gp = train_D(x, stamps, P_Real, text > 0, G, D,
                                 optimizer_D, args.lambda_gp,
                                 args.model.D.max_label_len)
            # cost_D, gp = train_D(x, P_Real, text>0, G, D, optimizer_D,
            #                      args.lambda_gp, args.model.G.len_seq, args.model.D.max_label_len)

        uttids, x = next(iter_feature_train)
        stamps = dataset_train.get_attrs('stamps', uttids.numpy())
        cost_G, fs = train_G(x, stamps, G, D, optimizer_G, args.lambda_fs)
        # cost_G, fs = train_G(x, G, D, optimizer_G,
        #                      args.lambda_fs, args.model.G.len_seq, args.model.D.max_label_len)

        loss_supervise = train_G_supervised(supervise_x, supervise_aligns, G,
                                            optimizer_G, args.dim_output,
                                            args.lambda_supervision)
        # loss_supervise, bounds_loss = train_G_bounds_supervised(
        #     supervise_x, supervise_bounds, supervise_aligns, G, optimizer_G, args.dim_output)

        num_processed += len(x)
        progress = num_processed / args.data.train_size
        if step % 10 == 0:
            print(
                'cost_G: {:.3f}|{:.3f}\tcost_D: {:.3f}|{:.3f}\tloss_supervise: {:.3f}\tbatch: {}|{}\tused: {:.3f}\t {:.3f}% iter: {}'
                .format(cost_G, fs, cost_D, gp, loss_supervise, x.shape,
                        text.shape,
                        time() - start, progress * 100.0, step))
            with writer.as_default():
                tf.summary.scalar("costs/cost_G", cost_G, step=step)
                tf.summary.scalar("costs/cost_D", cost_D, step=step)
                tf.summary.scalar("costs/gp", gp, step=step)
                tf.summary.scalar("costs/fs", fs, step=step)
                tf.summary.scalar("costs/loss_supervise",
                                  loss_supervise,
                                  step=step)
        if step % args.dev_step == 0:
            # fer, cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, G)
            fer, cer_0 = evaluate(feature_dev,
                                  dataset_dev,
                                  args.data.dev_size,
                                  G,
                                  beam_size=0,
                                  with_stamp=True)
            fer, cer = evaluate(feature_dev,
                                dataset_dev,
                                args.data.dev_size,
                                G,
                                beam_size=0,
                                with_stamp=False)
            with writer.as_default():
                tf.summary.scalar("performance/fer", fer, step=step)
                tf.summary.scalar("performance/cer_0", cer_0, step=step)
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            monitor(dataset_dev[0], G)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example #4
def train_gan():
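    """Semi-supervised GAN training for a two-encoder ASR model (TF1 graph mode).

    Reads bucketed multi-target TFRecords (features, phones, labels), builds G,
    an inference copy G_infer, the discriminator D and a GAN wrapper, then per
    iteration runs three discriminator updates fed with text batches and one
    generator update that returns CTC and CE losses, with periodic saving of
    the ASR variables, dev PER/CER/WER evaluation and decoding of a fixed test
    utterance.
    """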
    print('reading data from ', args.dirs.train.tfdata)
    dataReader_train = TFDataReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_multi_batch_bucket()
    dataReader_untrain = TFDataReader(args.dirs.untrain.tfdata, args=args)
    batch_untrain = dataReader_untrain.fentch_multi_batch(args.batch_size)
    args.dirs.untrain.tfdata = Path(args.dirs.untrain.tfdata)
    args.data.untrain_size = TFDataReader.read_tfdata_info(
        args.dirs.untrain.tfdata)['size_dataset']

    dataset_text = TextDataSet(list_files=[args.dirs.text.data],
                               args=args,
                               _shuffle=True)
    tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                  (tf.TensorShape([None])))
    iter_text = tfdata_train.cache().repeat().shuffle(1000).\
        padded_batch(args.text_batch_size, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()
    dataReader_dev = TFDataReader(args.dirs.dev.tfdata,
                                  args=args,
                                  _shuffle=False,
                                  transform=True)
    dataloader_dev = ASR_Multi_DataLoader(args.dataset_dev,
                                          args,
                                          dataReader_dev.feat,
                                          dataReader_dev.phone,
                                          dataReader_dev.label,
                                          batch_size=args.batch_size,
                                          num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    G = args.Model(tensor_global_step,
                   encoder=args.model.encoder.type,
                   encoder2=args.model.encoder2.type,
                   decoder=args.model.decoder.type,
                   batch=batch_train,
                   training=True,
                   args=args)

    G_infer = args.Model(tensor_global_step,
                         encoder=args.model.encoder.type,
                         encoder2=args.model.encoder2.type,
                         decoder=args.model.decoder.type,
                         training=False,
                         args=args)
    vars_ASR = G.trainable_variables()
    # vars_G_ocd = G.trainable_variables('Ectc_Docd/ocd_decoder')

    D = args.Model_D(tensor_global_step1,
                     training=True,
                     name='discriminator',
                     args=args)

    gan = args.GAN([tensor_global_step0, tensor_global_step1],
                   G,
                   D,
                   batch=batch_train,
                   unbatch=batch_untrain,
                   name='GAN',
                   args=args)

    size_variables()

    start_time = datetime.now()
    saver_ASR = tf.train.Saver(vars_ASR, max_to_keep=10)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_ASR.restore(sess, args.dirs.checkpoint_G)

        batch_time = time()
        num_processed = 0
        num_processed_unbatch = 0
        progress = 0
        while progress < args.num_epochs:

            # semi_supervise
            global_step, lr_G, lr_D = sess.run([
                tensor_global_step0, gan.learning_rate_G, gan.learning_rate_D
            ])

            for _ in range(3):
                text = sess.run(iter_text)
                text_lens = get_batch_length(text)
                shape_text = text.shape
                loss_D, loss_D_res, loss_D_text, loss_gp, _ = sess.run(
                    gan.list_train_D,
                    feed_dict={
                        gan.list_pl[0]: text,
                        gan.list_pl[1]: text_lens
                    })
                # loss_D=loss_D_res=loss_D_text=loss_gp=0
            (loss_G, ctc_loss, ce_loss, _), (shape_batch, shape_unbatch) = \
                sess.run([gan.list_train_G, gan.list_feature_shape])

            num_processed += shape_batch[0]
            # num_processed_unbatch += shape_unbatch[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size
            progress_unbatch = num_processed_unbatch / args.data.untrain_size

            if global_step % 40 == 0:
                print('ctc|ce loss: {:.2f}|{:.2f}, loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\t{}|{}\tlr:{:.1e}|{:.1e} {:.2f}s {:.3f}% step: {}'.format(
                       np.mean(ctc_loss), np.mean(ce_loss), loss_D_res, loss_D_text, loss_gp, shape_batch, \
                       shape_unbatch, lr_G, lr_D, used_time, progress*100, global_step))
                summary.summary_scalar('ctc_loss', np.mean(ctc_loss),
                                       global_step)
                summary.summary_scalar('ce_loss', np.mean(ce_loss),
                                       global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver_ASR.save(get_session(sess),
                               str(args.dir_checkpoint / 'model_G'),
                               global_step=global_step,
                               write_meta_graph=True)
                print(
                    'saved G in',
                    str(args.dir_checkpoint) + '/model_G-' + str(global_step))
                # saver_G_en.save(get_session(sess), str(args.dir_checkpoint/'model_G_en'), global_step=global_step, write_meta_graph=True)
                # print('saved model in',  str(args.dir_checkpoint)+'/model_G_en-'+str(global_step))

            if global_step % args.dev_step == args.dev_step - 1:
                # if True:
                per, cer, wer = dev(step=global_step,
                                    dataloader=dataloader_dev,
                                    model=G_infer,
                                    sess=sess,
                                    unit=args.data.unit,
                                    args=args)
                summary.summary_scalar('dev_per', per, global_step)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                # if True:
                decode_test(step=global_step,
                            sample=args.dataset_test.uttid2sample(
                                args.sample_uttid),
                            model=G_infer,
                            sess=sess,
                            unit=args.data.unit,
                            args=args)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example #5
def train():
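    """GAN training on pre-extracted feature/label/align records (TF2 eager).

    Takes one fixed supervised batch (x_0, y_0) of args.num_supervised
    utterances, then repeats: args.opti.D_G_rate discriminator updates
    (train_D), one adversarial generator update (train_G) and one supervised
    generator update (train_G_supervised) on the fixed batch, with periodic
    FER/CER evaluation and decoding of the first dev sample.
    """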
    dataset_dev = ASR_align_DataSet(
        file=[args.dirs.dev.data],
        args=args,
        _shuffle=False,
        transform=True)
    with tf.device("/cpu:0"):
        # wav data
        tfdata_train = TFData(dataset=None,
                        dataAttr=['feature', 'label', 'align'],
                        dir_save=args.dirs.train.tfdata,
                        args=args).read(_shuffle=False)
        tfdata_dev = TFData(dataset=None,
                        dataAttr=['feature', 'label', 'align'],
                        dir_save=args.dirs.dev.tfdata,
                        args=args).read(_shuffle=False)

        x_0, y_0, _ = next(iter(tfdata_train.take(args.num_supervised).map(lambda x, y, z: (x, y, z[:args.max_seq_len])).\
            padded_batch(args.num_supervised, ([None, args.dim_input], [None], [None]))))
        iter_train = iter(tfdata_train.cache().repeat().shuffle(3000).map(lambda x, y, z: (x, y, z[:args.max_seq_len])).\
            padded_batch(args.batch_size, ([None, args.dim_input], [None], [args.max_seq_len])).prefetch(buffer_size=3))
        tfdata_dev = tfdata_dev.padded_batch(args.batch_size, ([None, args.dim_input], [None], [None]))

        # text data
        dataset_text = TextDataSet(
            list_files=[args.dirs.lm.data],
            args=args,
            _shuffle=True)

        tfdata_train_text = tf.data.Dataset.from_generator(
            dataset_text, (tf.int32), (tf.TensorShape([None])))
        iter_text = iter(tfdata_train_text.cache().repeat().shuffle(100).map(lambda x: x[:args.max_seq_len]).padded_batch(args.batch_size,
            ([args.max_seq_len])).prefetch(buffer_size=5))

    # create model parameters
    G = PhoneClassifier(args)
    D = PhoneDiscriminator2(args)
    G.summary()
    D.summary()
    optimizer_G = tf.keras.optimizers.Adam(args.opti.G.lr, beta_1=0.5, beta_2=0.9)
    optimizer_D = tf.keras.optimizers.Adam(args.opti.D.lr, beta_1=0.5, beta_2=0.9)
    optimizer = tf.keras.optimizers.Adam(args.opti.G.lr, beta_1=0.5, beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt = tf.train.Checkpoint(G=G, optimizer_G = optimizer_G)
    ckpt_manager = tf.train.CheckpointManager(ckpt, args.dir_checkpoint, max_to_keep=5)
    step = 0

    # if a checkpoint exists, restore the latest checkpoint.
    if args.dirs.checkpoint:
        _ckpt_manager = tf.train.CheckpointManager(ckpt, args.dirs.checkpoint, max_to_keep=1)
        ckpt.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint {} restored!!'.format(_ckpt_manager.latest_checkpoint))
        step = int(_ckpt_manager.latest_checkpoint.split('-')[-1])

    start_time = datetime.now()
    num_processed = 0
    progress = 0

    while step < 99999999:
        start = time()

        for _ in range(args.opti.D_G_rate):
            x, _, aligns = next(iter_train)
            text = next(iter_text)
            P_Real = tf.one_hot(text, args.dim_output)
            cost_D, gp = train_D(x, aligns, P_Real, text>0, G, D, optimizer_D, args.lambda_gp)

        x, _, aligns = next(iter_train)
        cost_G, fs = train_G(x, aligns, G, D, optimizer_G, args.lambda_fs)
        loss_supervise = train_G_supervised(x_0, y_0, G, optimizer_G, args.dim_output)

        num_processed += len(x)
        if step % 10 == 0:
            print('cost_G: {:.3f}|{:.3f}\tcost_D: {:.3f}|{:.3f}\tloss_supervise: {:.3f}\tbatch: {}|{}\tused: {:.3f}\t {:.3f}% iter: {}'.format(
                   cost_G, fs, cost_D, gp, loss_supervise, x.shape, text.shape, time()-start, progress*100.0, step))
            with writer.as_default():
                tf.summary.scalar("costs/cost_G", cost_G, step=step)
                tf.summary.scalar("costs/cost_D", cost_D, step=step)
                tf.summary.scalar("costs/gp", gp, step=step)
                tf.summary.scalar("costs/fs", fs, step=step)
                tf.summary.scalar("costs/loss_supervise", loss_supervise, step=step)
        if step % args.dev_step == 0:
            fer, cer = evaluation(tfdata_dev, args.data.dev_size, G)
            with writer.as_default():
                tf.summary.scalar("performance/fer", fer, step=step)
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            decode(dataset_dev[0], G)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format((datetime.now()-start_time).total_seconds()/3600))
Example #6
def Train():
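    """Joint supervised CTC and adversarial training (TF2 eager).

    Builds ark-based datasets, an encoder/decoder pair and a CLM discriminator.
    Each step runs one supervised CTC update (train_CTC_supervised) on a
    labelled batch, then one generator update (train_G) and one discriminator
    update (train_D) on an unlabelled batch against sampled text, with periodic
    CER evaluation, decoding and checkpointing.
    """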
    args.data.untrain_size = TFData.read_tfdata_info(
        args.dirs.untrain.tfdata)['size_dataset']
    with tf.device("/cpu:0"):
        dataset_train = ASR_align_ArkDataSet(scp_file=args.dirs.train.scp,
                                             trans_file=args.dirs.train.trans,
                                             align_file=None,
                                             feat_len_file=None,
                                             args=args,
                                             _shuffle=False,
                                             transform=False)
        dataset_untrain = ASR_align_ArkDataSet(scp_file=args.dirs.untrain.scp,
                                               trans_file=None,
                                               align_file=None,
                                               feat_len_file=None,
                                               args=args,
                                               _shuffle=False,
                                               transform=False)
        dataset_dev = ASR_align_ArkDataSet(scp_file=args.dirs.dev.scp,
                                           trans_file=args.dirs.dev.trans,
                                           align_file=None,
                                           feat_len_file=None,
                                           args=args,
                                           _shuffle=False,
                                           transform=False)
        # wav data
        feature_train = TFData(dataset=dataset_train,
                               dir_save=args.dirs.train.tfdata,
                               args=args).read()
        feature_unsupervise = TFData(dataset=dataset_untrain,
                                     dir_save=args.dirs.untrain.tfdata,
                                     args=args).read()
        feature_dev = TFData(dataset=dataset_dev,
                             dir_save=args.dirs.dev.tfdata,
                             args=args).read()
        bucket = tf.data.experimental.bucket_by_sequence_length(
            element_length_func=lambda uttid, x: tf.shape(x)[0],
            bucket_boundaries=args.list_bucket_boundaries,
            bucket_batch_sizes=args.list_batch_size,
            padded_shapes=((), [None, args.dim_input]))
        iter_feature_train = iter(
            feature_train.repeat().shuffle(100).apply(bucket).prefetch(
                buffer_size=5))
        # iter_feature_unsupervise = iter(feature_unsupervise.repeat().shuffle(100).apply(bucket).prefetch(buffer_size=5))
        # iter_feature_train = iter(feature_train.repeat().shuffle(100).padded_batch(args.batch_size,
        #         ((), [None, args.dim_input])).prefetch(buffer_size=5))
        iter_feature_unsupervise = iter(
            feature_unsupervise.repeat().shuffle(100).padded_batch(
                args.batch_size,
                ((), [None, args.dim_input])).prefetch(buffer_size=5))
        # feature_dev = feature_dev.apply(bucket).prefetch(buffer_size=5)
        feature_dev = feature_dev.padded_batch(args.batch_size,
                                               ((), [None, args.dim_input]))

        dataset_text = TextDataSet(list_files=[args.dirs.lm.data],
                                   args=args,
                                   _shuffle=True)
        tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                      (tf.TensorShape([None])))
        iter_text = iter(tfdata_train.cache().repeat().shuffle(1000).map(
            lambda x: x[:args.model.D.max_label_len]).padded_batch(
                args.text_batch_size,
                ([args.model.D.max_label_len])).prefetch(buffer_size=5))

    # create model parameters
    encoder = Encoder(args)
    decoder = Decoder(args)
    D = CLM(args)
    encoder.summary()
    decoder.summary()
    D.summary()
    optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.5, beta_2=0.9)
    optimizer_D = tf.keras.optimizers.Adam(0.0001, beta_1=0.5, beta_2=0.9)

    writer = tf.summary.create_file_writer(str(args.dir_log))
    ckpt_G = tf.train.Checkpoint(encoder=encoder, decoder=decoder)
    ckpt_manager = tf.train.CheckpointManager(ckpt_G,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    step = 0

    if args.dirs.checkpoint_G:
        _ckpt_manager = tf.train.CheckpointManager(ckpt_G,
                                                   args.dirs.checkpoint_G,
                                                   max_to_keep=1)
        ckpt_G.restore(_ckpt_manager.latest_checkpoint)
        print('checkpoint_G {} restored!!'.format(
            _ckpt_manager.latest_checkpoint))
        # cer = evaluate(feature_dev, dataset_dev, args.data.dev_size, encoder, decoder)
        # with writer.as_default():
        #     tf.summary.scalar("performance/cer", cer, step=step)

    start_time = datetime.now()
    num_processed = 0
    while step < 99999999:
        start = time()

        # supervise training
        uttids, x = next(iter_feature_train)
        trans = dataset_train.get_attrs('trans', uttids.numpy())
        loss_supervise = train_CTC_supervised(x, trans, encoder, decoder,
                                              optimizer)

        # unsupervise training
        text = next(iter_text)
        _, un_x = next(iter_feature_unsupervise)
        loss_G = train_G(un_x, encoder, decoder, D, optimizer,
                         args.model.D.max_label_len)
        loss_D = train_D(un_x, text, encoder, decoder, D, optimizer_D,
                         args.lambda_gp, args.model.D.max_label_len)

        num_processed += len(un_x)
        progress = num_processed / args.data.untrain_size

        if step % 10 == 0:
            print(
                'loss_supervise: {:.3f}\tloss_G: {:.3f}\tloss_D: {:.3f}\tbatch: {}\tused: {:.3f}\t {:.3f}% step: {}'
                .format(loss_supervise, loss_G, loss_D, un_x.shape,
                        time() - start, progress * 100, step))
            with writer.as_default():
                tf.summary.scalar("costs/loss_supervise",
                                  loss_supervise,
                                  step=step)
        if step % args.dev_step == args.dev_step - 1:
            cer = evaluate(feature_dev, dataset_dev, args.data.dev_size,
                           encoder, decoder)
            with writer.as_default():
                tf.summary.scalar("performance/cer", cer, step=step)
        if step % args.decode_step == 0:
            monitor(dataset_dev[0], encoder, decoder)
        if step % args.save_step == 0:
            save_path = ckpt_manager.save(step)
            print('save model {}'.format(save_path))

        step += 1

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example #7
def train(Model):
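    """Plain supervised training loop for a text model (TF2 eager).

    Iterates over padded text batches, applies train_step with an Adam
    optimizer on a warmup/exponential-decay schedule, and periodically
    evaluates on the dev set and saves checkpoints.
    """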
    # create dataset and dataloader
    dataset_train = TextDataSet(list_files=[args.dirs.train.data],
                                args=args,
                                _shuffle=True)
    dataset_dev = TextDataSet(list_files=[args.dirs.dev.data],
                              args=args,
                              _shuffle=False)

    args.data.train_size = len(dataset_train)
    args.data.dev_size = len(dataset_dev)

    tfdata_train = tf.data.Dataset.from_generator(dataset_train, (tf.int32),
                                                  (tf.TensorShape([None])))
    tfdata_dev = tf.data.Dataset.from_generator(dataset_dev, (tf.int32),
                                                (tf.TensorShape([None])))

    tfdata_train = tfdata_train.cache().repeat().shuffle(500).padded_batch(
        args.batch_size, ([None])).prefetch(buffer_size=5)
    tfdata_dev = tfdata_dev.padded_batch(args.batch_size, ([None]))

    # build optimizer
    warmup = warmup_exponential_decay(warmup_steps=args.opti.warmup_steps,
                                      peak=args.opti.peak,
                                      decay_steps=args.opti.decay_steps)
    optimizer = tf.keras.optimizers.Adam(warmup,
                                         beta_1=0.9,
                                         beta_2=0.98,
                                         epsilon=1e-9)

    # create model parameters
    model = Model(args)
    model.summary()

    # save & reload
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              args.dir_checkpoint,
                                              max_to_keep=20)
    if args.dirs.restore:
        latest_checkpoint = tf.train.CheckpointManager(
            ckpt, args.dirs.restore, max_to_keep=1).latest_checkpoint
        ckpt.restore(latest_checkpoint)
        print('{} restored!!'.format(latest_checkpoint))

    start_time = datetime.now()
    get_data_time = 0
    num_processed = 0
    progress = 0

    for global_step, batch in enumerate(tfdata_train):
        x, y = batch

        run_model_time = time()
        loss = train_step(x, y, model, optimizer)

        num_processed += len(x)
        get_data_time = run_model_time - get_data_time
        run_model_time = time() - run_model_time

        progress = num_processed / args.data.train_size
        if global_step % 10 == 0:
            print(
                'loss: {:.5f}\t batch: {} lr:{:.6f} time: {:.2f}|{:.2f} s {:.3f}% step: {}'
                .format(loss, x.shape,
                        warmup(global_step * 1.0).numpy(), get_data_time,
                        run_model_time, progress * 100.0, global_step))
        get_data_time = time()

        if global_step % args.dev_step == 0:
            evaluation(tfdata_dev, model)
        # if global_step % args.decode_step == 0:
        #     decode(model)
        if global_step % args.save_step == 0:
            ckpt_manager.save()

    print('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example #8
def train():
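    """Text-only conditional GAN training (TF1 graph mode).

    Converts integer text batches into noisy pseudo-feature inputs with
    int2vector, then per iteration runs one generator update (supervised
    against a fixed text batch) and three discriminator updates, with periodic
    checkpointing and dev/test decoding.
    """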
    dataset_text = TextDataSet(list_files=[args.dirs.text.data], args=args, _shuffle=True)
    tfdata_train = tf.data.Dataset.from_generator(
        dataset_text, (tf.int32), (tf.TensorShape([None])))
    iter_text = tfdata_train.cache().repeat().shuffle(10000).\
        padded_batch(args.text_batch_size, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    dataset_text_supervise = TextDataSet(list_files=[args.dirs.text.supervise], args=args, _shuffle=True)
    tfdata_supervise = tf.data.Dataset.from_generator(
        dataset_text_supervise, (tf.int32), (tf.TensorShape([None])))
    iter_supervise = tfdata_supervise.cache().repeat().shuffle(100).\
        padded_batch(args.num_supervised, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    dataset_text_dev = TextDataSet(list_files=[args.dirs.text.dev], args=args, _shuffle=False)
    tfdata_dev = tf.data.Dataset.from_generator(
        dataset_text_dev, (tf.int32), (tf.TensorShape([None])))
    tfdata_dev = tfdata_dev.cache().\
        padded_batch(args.text_batch_size, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_initializable_iterator()
    iter_text_dev = tfdata_dev.get_next()

    tensor_global_step = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    G = Generator(tensor_global_step,
                  hidden=args.model.hidden_size,
                  num_blocks=args.model.num_blocks,
                  training=True,
                  args=args)
    G_infer = Generator(tensor_global_step,
                        hidden=args.model.hidden_size,
                        num_blocks=args.model.num_blocks,
                        training=False,
                        args=args)
    vars_G = G.trainable_variables

    D = args.Model_D(tensor_global_step1,
                     training=True,
                     name='discriminator',
                     args=args)

    gan = Conditional_GAN([tensor_global_step0, tensor_global_step1], G, D,
                     batch=None, unbatch=None, name='text_gan', args=args)

    size_variables()

    start_time = datetime.now()
    saver_G = tf.train.Saver(vars_G, max_to_keep=1)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        if args.dirs.checkpoint_G:
            saver_G.restore(sess, args.dirs.checkpoint_G)
            print('restored G from', args.dirs.checkpoint_G)

        # np.random.seed(0)
        text_supervise = sess.run(iter_supervise)
        text_len_supervise = get_batch_length(text_supervise)
        feature_supervise, feature_len_supervise = int2vector(text_supervise, text_len_supervise, hidden_size=args.model.dim_input, uprate=args.uprate)
        feature_supervise += np.random.randn(*feature_supervise.shape)/args.noise
        batch_time = time()
        global_step = 0

        # for _ in range(100):
        #     np.random.seed(1)
        #     text_G = sess.run(iter_text)
        #     text_lens_G = get_batch_length(text_G)
        #     feature_text, text_lens_G = int2vector(text_G, text_lens_G, hidden_size=args.model.dim_input , uprate=args.uprate)
        #     feature_text += np.random.randn(*feature_text.shape)/args.noise
        #     loss_G, loss_G_supervise, _ = sess.run(gan.list_train_G,
        #          feed_dict={gan.list_G_pl[0]:feature_text,
        #                     gan.list_G_pl[1]:text_lens_G,
        #                     gan.list_G_pl[2]:feature_supervise,
        #                     gan.list_G_pl[3]:feature_len_supervise,
        #                     gan.list_G_pl[4]:text_supervise,
        #                     gan.list_G_pl[5]:text_len_supervise})
        #     saver.save(get_session(sess), str(args.dir_checkpoint/'model'), global_step=0, write_meta_graph=True)

        while global_step < 99999999:

            # global_step, lr_G, lr_D = sess.run([tensor_global_step0, gan.learning_rate_G, gan.learning_rate_D])
            global_step, lr_G = sess.run([tensor_global_step0, G.learning_rate])

            # text_supervise = sess.run(iter_supervise)
            # text_len_supervise = get_batch_length(text_supervise)
            # feature_supervise, feature_len_supervise = int2vector(text_supervise, text_len_supervise, hidden_size=args.model.dim_input , uprate=args.uprate)
            # feature_supervise += np.random.randn(*feature_supervise.shape)/args.noise

            # supervise
            # for _ in range(1):
            #     loss_G_supervise, _ = sess.run(G.run_list,
            #                          feed_dict={G.list_pl[0]:feature_text_supervise,
            #                                     G.list_pl[1]:text_len_supervise,
            #                                     G.list_pl[2]:text_supervise,
            #                                     G.list_pl[3]:text_len_supervise})

            # generator input
            text_G = sess.run(iter_text)
            text_lens_G = get_batch_length(text_G)
            feature_text, text_lens_G = int2vector(text_G, text_lens_G, hidden_size=args.model.dim_input, uprate=args.uprate)
            feature_text += np.random.randn(*feature_text.shape)/args.noise
            loss_G, loss_G_supervise, _ = sess.run(gan.list_train_G,
                 feed_dict={gan.list_G_pl[0]:feature_text,
                            gan.list_G_pl[1]:text_lens_G,
                            gan.list_G_pl[2]:feature_supervise,
                            gan.list_G_pl[3]:feature_len_supervise,
                            gan.list_G_pl[4]:text_supervise,
                            gan.list_G_pl[5]:text_len_supervise})
            # loss_G = loss_G_supervise = 0

            # discriminator input
            for _ in range(3):
                # np.random.seed(2)
                text_G = sess.run(iter_text)
                text_lens_G = get_batch_length(text_G)
                feature_G, feature_lens_G = int2vector(text_G, text_lens_G, hidden_size=args.model.dim_input, uprate=args.uprate)
                feature_G += np.random.randn(*feature_G.shape)/args.noise

                text_D = sess.run(iter_text)
                text_lens_D = get_batch_length(text_D)
                shape_text = text_D.shape
                loss_D, loss_D_res, loss_D_text, loss_gp, _ = sess.run(gan.list_train_D,
                        feed_dict={gan.list_D_pl[0]:text_D,
                                   gan.list_D_pl[1]:text_lens_D,
                                   gan.list_G_pl[0]:feature_G,
                                   gan.list_G_pl[1]:feature_lens_G})
            # loss_D_res = loss_D_text = loss_gp = 0
            # shape_text = [0,0,0]

            # loss_D_res = - loss_G
            # loss_G = loss_G_supervise = 0.0
            # loss_D = loss_D_text = loss_gp = 0.0
            # train
            # if global_step % 5 == 0:
            # for _ in range(2):
            #     loss_supervise, shape_batch, _, _ = sess.run(G.list_run)
            # loss_G_supervise = 0
            used_time = time()-batch_time
            batch_time = time()

            if global_step % 10 == 0:
                # print('loss_G: {:.2f} loss_G_supervise: {:.2f} loss_D_res: {:.2f} loss_D_text: {:.2f} step: {}'.format(
                #        loss_G, loss_G_supervise, loss_D_res, loss_D_text, global_step))
                print('loss_G_supervise: {:.2f} loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\tbatch: {}\tlr:{:.1e} {:.2f}s step: {}'.format(
                       loss_G_supervise, loss_D_res, loss_D_text, loss_gp, shape_text, lr_G, used_time, global_step))
                # summary.summary_scalar('loss_G', loss_G, global_step)
                # summary.summary_scalar('loss_D', loss_D, global_step)
                # summary.summary_scalar('lr_G', lr_G, global_step)
                # summary.summary_scalar('lr_D', lr_D, global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess), str(args.dir_checkpoint/'model'), global_step=global_step, write_meta_graph=True)
                print('saved model as',  str(args.dir_checkpoint) + '/model-' + str(global_step))
            # if global_step % args.dev_step == args.dev_step - 1:
            if global_step % args.dev_step == 1:
                text_G_dev = dev(iter_text_dev, tfdata_dev, dataset_text_dev, sess, G_infer)

            if global_step % args.decode_step == args.decode_step - 1:
                decode(text_G_dev, tfdata_dev, sess, G_infer)

    logging.info('training duration: {:.2f}h'.format((datetime.now()-start_time).total_seconds()/3600))