Example #1
0
def train():
    """Train the pinyin-to-text generator inside a conditional text GAN.

    Builds three ``tf.data`` pipelines from ``PinyinDataSet`` files: the
    training text (``args.dirs.text.data``), a small supervised set
    (``args.dirs.text.supervise``) and a dev set (``args.dirs.text.dev``).
    It then constructs the generator ``G`` (plus an inference copy
    ``G_infer``), the discriminator ``D`` and the ``Conditional_GAN`` wrapper.

    Inside a ``MonitoredTrainingSession`` it optionally restores ``G`` from
    ``args.dirs.checkpoint_G``. It then runs 500 supervised steps of
    ``G.run_list`` on one fixed supervised batch. After that it loops over
    ``gan.list_train_G`` updates, periodically printing losses, saving
    checkpoints, running ``dev`` and decoding the most recent dev batch.
    All configuration comes from the module-level ``args``.
    """
    padded_shapes = ([args.max_label_len], [args.max_label_len])

    def _pinyin_dataset(path, shuffle):
        """Return ``(PinyinDataSet, cached tf.data.Dataset)`` for one file."""
        dataset = PinyinDataSet(list_files=[path], args=args, _shuffle=shuffle)
        tfdata = tf.data.Dataset.from_generator(
            dataset, (tf.int32, tf.int32),
            (tf.TensorShape([None]), tf.TensorShape([None])))
        return dataset, tfdata.cache()

    def _noisy_features(pinyin, pinyin_lens):
        """Convert padded pinyin ids to generator input features plus noise.

        Returns the ``(features, feature_lens)`` pair produced by
        ``int2vector``. The features get Gaussian noise scaled by
        ``1 / args.noise`` added in place.
        """
        features, feature_lens = int2vector(pinyin,
                                            pinyin_lens,
                                            hidden_size=args.model.dim_input,
                                            uprate=args.uprate)
        features += np.random.randn(*features.shape) / args.noise
        return features, feature_lens

    # Unlabelled training text: endless, shuffled.
    _, tfdata_train = _pinyin_dataset(args.dirs.text.data, shuffle=True)
    iter_train = tfdata_train.repeat().shuffle(10000).\
        padded_batch(args.text_batch_size, padded_shapes).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    # Small supervised set: endless, one batch of args.num_supervised at a time.
    _, tfdata_supervise = _pinyin_dataset(args.dirs.text.supervise, shuffle=True)
    iter_supervise = tfdata_supervise.repeat().shuffle(100).\
        padded_batch(args.num_supervised, padded_shapes).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    # Dev set: single pass through an initializable iterator, which is handed
    # to dev() below (presumably so it can be re-initialised per evaluation).
    dataset_dev, tfdata_dev = _pinyin_dataset(args.dirs.text.dev, shuffle=False)
    iterator_dev = tfdata_dev.\
        padded_batch(args.text_batch_size, padded_shapes).prefetch(buffer_size=5).\
        make_initializable_iterator()
    iter_dev = iterator_dev.get_next()

    tensor_global_step = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    G = Generator(tensor_global_step, training=True, args=args)
    G_infer = Generator(tensor_global_step, training=False, args=args)
    vars_G = G.trainable_variables

    D = args.Model_D(tensor_global_step1,
                     training=True,
                     name='discriminator',
                     args=args)

    gan = Conditional_GAN([tensor_global_step0, tensor_global_step1],
                          G,
                          D,
                          batch=None,
                          unbatch=None,
                          name='text_gan',
                          args=args)

    size_variables()

    start_time = datetime.now()
    # Savers must be created after every model variable exists.
    saver_G = tf.train.Saver(vars_G, max_to_keep=1)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        if args.dirs.checkpoint_G:
            saver_G.restore(sess, args.dirs.checkpoint_G)
            print('revocer G from', args.dirs.checkpoint_G)

        # Supervised warm-up of G on one fixed supervised batch. Text lengths
        # are taken from the pinyin batch, as in the GAN feed below.
        pinyin_supervise, text_supervise = sess.run(iter_supervise)
        text_len_supervise = pinyin_len_supervise = get_batch_length(
            pinyin_supervise)
        feature_supervise, feature_len_supervise = _noisy_features(
            pinyin_supervise, pinyin_len_supervise)
        for _ in range(500):
            sess.run(G.run_list,
                     feed_dict={
                         G.list_pl[0]: feature_supervise,
                         G.list_pl[1]: feature_len_supervise,
                         G.list_pl[2]: text_supervise,
                         G.list_pl[3]: text_len_supervise
                     })

        batch_time = time()
        global_step = 0
        # decode() needs a dev batch. Only decode once dev() has produced one,
        # otherwise a decode step arriving before the first dev step (always
        # the case when dev_step == 1) would hit an unbound name.
        pinyin_G_dev = text_G_dev = None

        while global_step < 99999999:
            global_step, lr_G = sess.run(
                [tensor_global_step0, G.learning_rate])

            # Fresh supervised batch for the supervised term of the G loss.
            pinyin_supervise, text_supervise = sess.run(iter_supervise)
            pinyin_len_supervise = get_batch_length(pinyin_supervise)
            feature_supervise, feature_len_supervise = _noisy_features(
                pinyin_supervise, pinyin_len_supervise)

            # Unlabelled generator input.
            pinyin_G, _ = sess.run(iter_train)
            feature_G, lens_G = _noisy_features(pinyin_G,
                                                get_batch_length(pinyin_G))
            loss_G, loss_G_supervise, _ = sess.run(gan.list_train_G,
                                                   feed_dict={
                                                       gan.list_G_pl[0]:
                                                       feature_G,
                                                       gan.list_G_pl[1]:
                                                       lens_G,
                                                       gan.list_G_pl[2]:
                                                       feature_supervise,
                                                       gan.list_G_pl[3]:
                                                       feature_len_supervise,
                                                       gan.list_G_pl[4]:
                                                       text_supervise,
                                                       gan.list_G_pl[5]:
                                                       pinyin_len_supervise
                                                   })

            # The discriminator update is disabled; its logged values stay 0.
            loss_D_res = loss_D_text = loss_gp = 0
            shape_text = [0, 0, 0]

            used_time = time() - batch_time
            batch_time = time()

            if global_step % 10 == 0:
                print(
                    'loss_G_supervise: {:.2f} loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\tbatch: {}\tlr:{:.1e} {:.2f}s step: {}'
                    .format(loss_G_supervise, loss_D_res, loss_D_text, loss_gp,
                            shape_text, lr_G, used_time, global_step))

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess),
                           str(args.dir_checkpoint / 'model'),
                           global_step=global_step,
                           write_meta_graph=True)
                print('saved model as',
                      str(args.dir_checkpoint) + '/model-' + str(global_step))

            if global_step % args.dev_step == 1:
                pinyin_G_dev, text_G_dev = dev(iter_dev, iterator_dev,
                                               dataset_dev, sess, G_infer)

            if (global_step % args.decode_step == args.decode_step - 1
                    and pinyin_G_dev is not None):
                decode(pinyin_G_dev, text_G_dev, sess, G_infer)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example #2
0
def train():
    """Pre-train the ASR model ``G`` with ``G.list_run``, then train it in a GAN.

    Data sources:
      * labelled batches from ``args.dirs.train.tfdata``
        (``fentch_batch_bucket``) and unlabelled batches from
        ``args.dirs.untrain.tfdata`` (``fentch_batch``);
      * a text-only pipeline built from ``TextDataSet`` over
        ``args.dirs.text.data``;
      * a dev loader over ``args.dirs.dev.tfdata``.

    Side effects on ``args``: ``args.dirs.untrain.tfdata`` is converted to a
    ``Path`` and ``args.data.untrain_size`` is read from the untrain tfdata
    info.

    Inside a ``MonitoredTrainingSession`` it optionally restores ``G`` from
    ``args.dirs.checkpoint_G``. It then runs 2000 warm-up steps of
    ``G.list_run``, printing ``ctc_loss`` every 400 steps. After that it runs
    ``gan.list_train_G`` until ``args.num_epochs`` epochs of labelled data
    have been processed, periodically logging, checkpointing, evaluating on
    dev and decoding one test sample. The discriminator update is disabled,
    so its logged losses are zero.
    """
    print('reading data form ', args.dirs.train.tfdata)
    dataReader_train = TFReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_batch_bucket()
    dataReader_untrain = TFReader(args.dirs.untrain.tfdata, args=args)
    batch_untrain = dataReader_untrain.fentch_batch(args.batch_size)
    args.dirs.untrain.tfdata = Path(args.dirs.untrain.tfdata)
    args.data.untrain_size = TFData.read_tfdata_info(
        args.dirs.untrain.tfdata)['size_dataset']

    dataset_text = TextDataSet(list_files=[args.dirs.text.data],
                               args=args,
                               _shuffle=True)
    tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                  (tf.TensorShape([None])))
    iter_text = tfdata_train.cache().repeat().shuffle(1000).\
        padded_batch(args.text_batch_size, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()

    feat, label = readTFRecord(args.dirs.dev.tfdata,
                               args,
                               _shuffle=False,
                               transform=True)
    dataloader_dev = ASRDataLoader(args.dataset_dev,
                                   args,
                                   feat,
                                   label,
                                   batch_size=args.batch_size,
                                   num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    G = args.Model(tensor_global_step,
                   encoder=args.model.encoder.type,
                   decoder=args.model.decoder.type,
                   batch=batch_train,
                   training=True,
                   args=args)

    G_infer = args.Model(tensor_global_step,
                         encoder=args.model.encoder.type,
                         decoder=args.model.decoder.type,
                         training=False,
                         args=args)
    vars_G = G.trainable_variables

    D = args.Model_D(tensor_global_step1,
                     training=True,
                     name='discriminator',
                     args=args)

    gan = args.GAN([tensor_global_step0, tensor_global_step1],
                   G,
                   D,
                   batch=batch_train,
                   unbatch=batch_untrain,
                   name='GAN',
                   args=args)

    size_variables()

    start_time = datetime.now()
    # Savers must be created after every model variable exists.
    saver_G = tf.train.Saver(vars_G, max_to_keep=1)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_G.restore(sess, args.dirs.checkpoint_G)

        # Supervised warm-up of G before adversarial training; log every
        # 400th step.
        for i in range(2000):
            ctc_loss, *_ = sess.run(G.list_run)
            if i % 400 == 0:
                print('ctc_loss: {:.2f}'.format(ctc_loss))

        batch_time = time()
        num_processed = 0
        num_processed_unbatch = 0
        progress = 0
        while progress < args.num_epochs:

            global_step, lr_G, lr_D = sess.run([
                tensor_global_step0, gan.learning_rate_G, gan.learning_rate_D
            ])

            # While the discriminator update is disabled, the text batch is
            # only used for its shape in the log line.
            shape_text = sess.run(iter_text).shape
            loss_D_res = loss_D_text = loss_gp = 0
            (loss_G, loss_G_supervise, _), (shape_feature, shape_unfeature) = \
                sess.run([gan.list_train_G, gan.list_feature_shape])

            num_processed += shape_feature[0]
            num_processed_unbatch += shape_unfeature[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size
            progress_unbatch = num_processed_unbatch / args.data.untrain_size

            if global_step % 20 == 0:
                print(
                    'loss_supervise: {:.2f}, loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\tbatch: {}|{}|{}\tlr:{:.1e}|{:.1e} {:.2f}s {:.1f}%|{:.1f}% step: {}'
                    .format(loss_G_supervise, loss_D_res, loss_D_text, loss_gp,
                            shape_feature, shape_unfeature, shape_text, lr_G,
                            lr_D, used_time, progress * 100.0,
                            progress_unbatch * 100.0, global_step))

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess),
                           str(args.dir_checkpoint / 'model'),
                           global_step=global_step,
                           write_meta_graph=True)

            if global_step % args.dev_step == 0:
                cer, wer = dev(step=global_step,
                               dataloader=dataloader_dev,
                               model=G_infer,
                               sess=sess,
                               unit=args.data.unit,
                               idx2token=args.idx2token,
                               token2idx=args.token2idx)

            if global_step % args.decode_step == args.decode_step - 1:
                decode_test(step=global_step,
                            sample=args.dataset_test[10],
                            model=G_infer,
                            sess=sess,
                            unit=args.data.unit,
                            idx2token=args.idx2token,
                            token2idx=args.token2idx)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example #3
0
def train():
    """Supervised training of ``args.Model`` with an optional LM warm start.

    Reads shuffled, bucketed training batches from ``args.dirs.train.tfdata``
    and an unshuffled dev set from ``args.dirs.dev.tfdata``. It builds a
    training model and an inference copy that share the global step.

    Restore order inside the session:
      * if ``args.dirs.checkpoint`` is set, restore every variable from it;
      * otherwise, if ``args.dirs.lm_checkpoint`` is set, restore only the
        decoder LM variables from the latest checkpoint in that directory.

    Training runs ``model.list_run`` until ``args.num_epochs`` epochs of
    ``args.data.train_size`` examples have been processed. Along the way it
    logs and writes scalar summaries every 50 steps, and checkpoints,
    evaluates on dev and decodes one test utterance every ``save_step`` /
    ``dev_step`` / ``decode_step`` steps.

    Raises:
        ValueError: if a decoder LM variable has no matching entry in the
            LM checkpoint, or no checkpoint exists in
            ``args.dirs.lm_checkpoint`` when it is needed.
    """
    print('reading data form ', args.dirs.train.tfdata)
    dataReader_train = TFDataReader(args.dirs.train.tfdata,
                                    args=args,
                                    _shuffle=True,
                                    transform=True)
    dataReader_dev = TFDataReader(args.dirs.dev.tfdata,
                                  args=args,
                                  _shuffle=False,
                                  transform=True)

    batch_train = dataReader_train.fentch_batch_bucket()
    dataloader_dev = ASRDataLoader(args.dataset_dev,
                                   args,
                                   dataReader_dev.feat,
                                   dataReader_dev.label,
                                   batch_size=args.batch_size,
                                   num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()

    model = args.Model(tensor_global_step,
                       encoder=args.model.encoder.type,
                       decoder=args.model.decoder.type,
                       batch=batch_train,
                       training=True,
                       args=args)

    model_infer = args.Model(tensor_global_step,
                             encoder=args.model.encoder.type,
                             decoder=args.model.decoder.type,
                             training=False,
                             args=args)

    size_variables()
    start_time = datetime.now()

    saver = tf.train.Saver(max_to_keep=15)
    if args.dirs.lm_checkpoint:
        from tfTools.checkpointTools import list_variables

        list_lm_vars_pretrained = list_variables(args.dirs.lm_checkpoint)
        lm_scope = model.decoder.lm.name

        # 'var_name_in_checkpoint': var_in_graph
        dict_lm_vars = {}
        for var in model.decoder.lm.variables:
            # Embeddings are matched on the substring 'embedding'; every other
            # variable on its name relative to the LM scope, without ':0'.
            if 'embedding' in var.name:
                key = 'embedding'
            else:
                key = var.name.split(lm_scope)[1].split(':')[0]
            for var_pre in list_lm_vars_pretrained:
                if key in var_pre[0]:
                    dict_lm_vars[var_pre[0]] = var
                    break
            else:
                # Fail loudly rather than binding the graph variable to an
                # unrelated checkpoint entry.
                raise ValueError(
                    'no variable matching {!r} (for {}) in LM checkpoint {}'.
                    format(key, var.name, args.dirs.lm_checkpoint))

        saver_lm = tf.train.Saver(dict_lm_vars)

    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        if args.dirs.checkpoint:
            saver.restore(sess, args.dirs.checkpoint)
        elif args.dirs.lm_checkpoint:
            lm_checkpoint = tf.train.latest_checkpoint(args.dirs.lm_checkpoint)
            if lm_checkpoint is None:
                raise ValueError('no checkpoint found in {}'.format(
                    args.dirs.lm_checkpoint))
            saver_lm.restore(sess, lm_checkpoint)

        dataloader_dev.sess = sess

        batch_time = time()
        num_processed = 0
        progress = 0
        while progress < args.num_epochs:
            global_step, lr = sess.run(
                [tensor_global_step, model.learning_rate])
            loss, shape_batch, _, _ = sess.run(model.list_run)
            num_processed += shape_batch[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size

            if global_step % 50 == 0:
                print(
                    'loss: {:.3f}\tbatch: {} lr:{:.6f} time:{:.2f}s {:.3f}% step: {}'
                    .format(loss, shape_batch, lr, used_time, progress * 100.0,
                            global_step))
                summary.summary_scalar('loss', loss, global_step)
                summary.summary_scalar('lr', lr, global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver.save(get_session(sess),
                           str(args.dir_checkpoint / 'model'),
                           global_step=global_step,
                           write_meta_graph=True)
                print('saved model in',
                      str(args.dir_checkpoint) + '/model-' + str(global_step))

            if global_step % args.dev_step == args.dev_step - 1:
                cer, wer = dev(step=global_step,
                               dataloader=dataloader_dev,
                               model=model_infer,
                               sess=sess,
                               unit=args.data.unit,
                               token2idx=args.token2idx,
                               idx2token=args.idx2token)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                decode_test(step=global_step,
                            sample=args.dataset_test.uttid2sample(
                                args.sample_uttid),
                            model=model_infer,
                            sess=sess,
                            unit=args.data.unit,
                            idx2token=args.idx2token,
                            token2idx=args.token2idx)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
def train():
    """Supervised training of the two-encoder model ``args.Model``.

    Training batches come from ``fentch_multi_batch_bucket()`` on
    ``args.dirs.train.tfdata``. The dev loader (``ASR_Multi_DataLoader``)
    reads feat/phone/label from ``args.dirs.dev.tfdata``.

    A training model ``G`` and an inference copy ``G_infer`` are built.
    ``G``'s variables, and separately its spiker sub-scope, can be restored
    from ``args.dirs.checkpoint_G`` / ``args.dirs.checkpoint_S``.
    ``G.list_run`` is then executed until ``args.num_epochs`` epochs of
    ``args.data.train_size`` examples have been processed in this run.

    Losses are printed every 40 steps. Checkpoints, dev evaluation and test
    decoding run every ``save_step`` / ``dev_step`` / ``decode_step`` steps.
    All configuration comes from the module-level ``args``.
    """
    print('reading data form ', args.dirs.train.tfdata)
    dataReader_train = TFDataReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_multi_batch_bucket()

    dataReader_dev = TFDataReader(args.dirs.dev.tfdata,
                                  args=args,
                                  _shuffle=False,
                                  transform=True)
    dataloader_dev = ASR_Multi_DataLoader(args.dataset_dev,
                                          args,
                                          dataReader_dev.feat,
                                          dataReader_dev.phone,
                                          dataReader_dev.label,
                                          batch_size=args.batch_size,
                                          num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()

    # Training graph fed by batch_train, plus an inference graph; both are
    # built with the same global-step tensor.
    G = args.Model(tensor_global_step,
                   encoder=args.model.encoder.type,
                   encoder2=args.model.encoder2.type,
                   decoder=args.model.decoder.type,
                   batch=batch_train,
                   training=True,
                   args=args)

    G_infer = args.Model(tensor_global_step,
                         encoder=args.model.encoder.type,
                         encoder2=args.model.encoder2.type,
                         decoder=args.model.decoder.type,
                         training=False,
                         args=args)
    # All of G's trainable variables, and the subset for the
    # '<G.name>/spiker' scope (assumes trainable_variables() filters by scope
    # prefix -- confirm against the model class).
    vars_ASR = G.trainable_variables()
    vars_spiker = G.trainable_variables(G.name + '/spiker')

    size_variables()

    start_time = datetime.now()
    saver_ASR = tf.train.Saver(vars_ASR, max_to_keep=30)
    saver_S = tf.train.Saver(vars_spiker, max_to_keep=30)
    # NOTE(review): `saver` is never used below.
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))
    # Added to the session's global step so logged/saved step numbers
    # continue from the step encoded in a restored ASR checkpoint name.
    step_bias = 0

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_ASR.restore(sess, args.dirs.checkpoint_G)
            # Assumes the checkpoint path ends in '-<step>' (e.g. 'model-12000');
            # any other suffix raises ValueError here.
            step_bias = int(args.dirs.checkpoint_G.split('-')[-1])
        if args.dirs.checkpoint_S:
            saver_S.restore(sess, args.dirs.checkpoint_S)

        batch_time = time()
        # Starts at 0 even when resuming, so `progress` counts epochs done in
        # this run only.
        num_processed = 0
        progress = 0
        while progress < args.num_epochs:

            # supervised training
            global_step, lr = sess.run([tensor_global_step, G.learning_rate])
            global_step += step_bias
            # The 4th element of G.list_run carries (ctc_loss, ce_loss, ...).
            loss_G, shape_batch, _, (ctc_loss, ce_loss,
                                     *_) = sess.run(G.list_run)
            num_processed += shape_batch[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size

            if global_step % 40 == 0:
                print(
                    'ctc_loss: {:.2f},  ce_loss: {:.2f} batch: {} lr:{:.1e} {:.2f}s {:.3f}% step: {}'
                    .format(np.mean(ctc_loss), np.mean(ce_loss), shape_batch,
                            lr, used_time, progress * 100, global_step))

            # Save the ASR variables and the spiker variables as two separate
            # checkpoints ('model-<step>' and 'model_S-<step>').
            if global_step % args.save_step == args.save_step - 1:
                saver_ASR.save(get_session(sess),
                               str(args.dir_checkpoint / 'model'),
                               global_step=global_step)
                print('saved ASR model in',
                      str(args.dir_checkpoint) + '/model-' + str(global_step))
                saver_S.save(get_session(sess),
                             str(args.dir_checkpoint / 'model_S'),
                             global_step=global_step)
                print(
                    'saved Spiker model in',
                    str(args.dir_checkpoint) + '/model_S-' + str(global_step))

            if global_step % args.dev_step == args.dev_step - 1:
                # if global_step % args.dev_step == 0:
                per, cer, wer = dev(step=global_step,
                                    dataloader=dataloader_dev,
                                    model=G_infer,
                                    sess=sess,
                                    unit=args.data.unit,
                                    args=args)
                summary.summary_scalar('dev_per', per, global_step)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                # if True:
                decode_test(step=global_step,
                            sample=args.dataset_test.uttid2sample(
                                args.sample_uttid),
                            model=G_infer,
                            sess=sess,
                            unit=args.data.unit,
                            args=args)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
def train_gan():
    """Semi-supervised GAN training of the two-encoder ASR model.

    Data sources:
      * labelled batches from ``args.dirs.train.tfdata``
        (``fentch_multi_batch_bucket``) and unlabelled batches from
        ``args.dirs.untrain.tfdata`` (``fentch_multi_batch``);
      * a text-only pipeline built from ``TextDataSet`` over
        ``args.dirs.text.data``, fed to the discriminator;
      * a dev loader (``ASR_Multi_DataLoader``) over ``args.dirs.dev.tfdata``.

    Side effects on ``args``: ``args.dirs.untrain.tfdata`` becomes a ``Path``
    and ``args.data.untrain_size`` is read from the untrain tfdata info.

    Each outer step runs 3 discriminator updates (``gan.list_train_G`` is
    preceded by 3 runs of ``gan.list_train_D``), then one generator update.
    Training stops after ``args.num_epochs`` epochs of labelled data.
    Logging, checkpointing of ``G``'s variables, dev evaluation and test
    decoding happen periodically.
    """
    print('reading data form ', args.dirs.train.tfdata)
    dataReader_train = TFDataReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_multi_batch_bucket()
    dataReader_untrain = TFDataReader(args.dirs.untrain.tfdata, args=args)
    batch_untrain = dataReader_untrain.fentch_multi_batch(args.batch_size)
    args.dirs.untrain.tfdata = Path(args.dirs.untrain.tfdata)
    args.data.untrain_size = TFDataReader.read_tfdata_info(
        args.dirs.untrain.tfdata)['size_dataset']

    # Endless, shuffled, padded batches of text token ids for the
    # discriminator's "real" input.
    dataset_text = TextDataSet(list_files=[args.dirs.text.data],
                               args=args,
                               _shuffle=True)
    tfdata_train = tf.data.Dataset.from_generator(dataset_text, (tf.int32),
                                                  (tf.TensorShape([None])))
    iter_text = tfdata_train.cache().repeat().shuffle(1000).\
        padded_batch(args.text_batch_size, ([args.max_label_len])).prefetch(buffer_size=5).\
        make_one_shot_iterator().get_next()
    dataReader_dev = TFDataReader(args.dirs.dev.tfdata,
                                  args=args,
                                  _shuffle=False,
                                  transform=True)
    dataloader_dev = ASR_Multi_DataLoader(args.dataset_dev,
                                          args,
                                          dataReader_dev.feat,
                                          dataReader_dev.phone,
                                          dataReader_dev.label,
                                          batch_size=args.batch_size,
                                          num_loops=1)

    # Separate step counters: the models use the global step, while the GAN
    # is given tensor_global_step0 / tensor_global_step1 (the latter is also
    # the discriminator's step).
    tensor_global_step = tf.train.get_or_create_global_step()
    tensor_global_step0 = tf.Variable(0, dtype=tf.int32, trainable=False)
    tensor_global_step1 = tf.Variable(0, dtype=tf.int32, trainable=False)

    G = args.Model(tensor_global_step,
                   encoder=args.model.encoder.type,
                   encoder2=args.model.encoder2.type,
                   decoder=args.model.decoder.type,
                   batch=batch_train,
                   training=True,
                   args=args)

    G_infer = args.Model(tensor_global_step,
                         encoder=args.model.encoder.type,
                         encoder2=args.model.encoder2.type,
                         decoder=args.model.decoder.type,
                         training=False,
                         args=args)
    vars_ASR = G.trainable_variables()
    # vars_G_ocd = G.trainable_variables('Ectc_Docd/ocd_decoder')

    D = args.Model_D(tensor_global_step1,
                     training=True,
                     name='discriminator',
                     args=args)

    gan = args.GAN([tensor_global_step0, tensor_global_step1],
                   G,
                   D,
                   batch=batch_train,
                   unbatch=batch_untrain,
                   name='GAN',
                   args=args)

    size_variables()

    start_time = datetime.now()
    # Only G's trainable variables are saved/restored. The discriminator is
    # not checkpointed; `saver` is created but never used below.
    saver_ASR = tf.train.Saver(vars_ASR, max_to_keep=10)
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_ASR.restore(sess, args.dirs.checkpoint_G)

        batch_time = time()
        num_processed = 0
        num_processed_unbatch = 0
        progress = 0
        while progress < args.num_epochs:

            # semi_supervise
            global_step, lr_G, lr_D = sess.run([
                tensor_global_step0, gan.learning_rate_G, gan.learning_rate_D
            ])

            # Three discriminator updates per generator update, each on a
            # fresh text batch.
            for _ in range(3):
                text = sess.run(iter_text)
                text_lens = get_batch_length(text)
                shape_text = text.shape
                loss_D, loss_D_res, loss_D_text, loss_gp, _ = sess.run(
                    gan.list_train_D,
                    feed_dict={
                        gan.list_pl[0]: text,
                        gan.list_pl[1]: text_lens
                    })
                # loss_D=loss_D_res=loss_D_text=loss_gp=0
            (loss_G, ctc_loss, ce_loss, _), (shape_batch, shape_unbatch) = \
                sess.run([gan.list_train_G, gan.list_feature_shape])

            num_processed += shape_batch[0]
            # NOTE(review): with the next line disabled, num_processed_unbatch
            # stays 0, so progress_unbatch is always 0 (and it is not logged).
            # num_processed_unbatch += shape_unbatch[0]
            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size
            progress_unbatch = num_processed_unbatch / args.data.untrain_size

            if global_step % 40 == 0:
                print('ctc|ce loss: {:.2f}|{:.2f}, loss res|real|gp: {:.2f}|{:.2f}|{:.2f}\t{}|{}\tlr:{:.1e}|{:.1e} {:.2f}s {:.3f}% step: {}'.format(
                       np.mean(ctc_loss), np.mean(ce_loss), loss_D_res, loss_D_text, loss_gp, shape_batch, \
                       shape_unbatch, lr_G, lr_D, used_time, progress*100, global_step))
                summary.summary_scalar('ctc_loss', np.mean(ctc_loss),
                                       global_step)
                summary.summary_scalar('ce_loss', np.mean(ce_loss),
                                       global_step)

            if global_step % args.save_step == args.save_step - 1:
                saver_ASR.save(get_session(sess),
                               str(args.dir_checkpoint / 'model_G'),
                               global_step=global_step,
                               write_meta_graph=True)
                print(
                    'saved G in',
                    str(args.dir_checkpoint) + '/model_G-' + str(global_step))
                # saver_G_en.save(get_session(sess), str(args.dir_checkpoint/'model_G_en'), global_step=global_step, write_meta_graph=True)
                # print('saved model in',  str(args.dir_checkpoint)+'/model_G_en-'+str(global_step))

            if global_step % args.dev_step == args.dev_step - 1:
                # if True:
                per, cer, wer = dev(step=global_step,
                                    dataloader=dataloader_dev,
                                    model=G_infer,
                                    sess=sess,
                                    unit=args.data.unit,
                                    args=args)
                summary.summary_scalar('dev_per', per, global_step)
                summary.summary_scalar('dev_cer', cer, global_step)
                summary.summary_scalar('dev_wer', wer, global_step)

            if global_step % args.decode_step == args.decode_step - 1:
                # if True:
                decode_test(step=global_step,
                            sample=args.dataset_test.uttid2sample(
                                args.sample_uttid),
                            model=G_infer,
                            sess=sess,
                            unit=args.data.unit,
                            args=args)

    logging.info('training duration: {:.2f}h'.format(
        (datetime.now() - start_time).total_seconds() / 3600))
Example #6
0
def train():
    """Jointly train the model on labeled and unlabeled data.

    Even global steps run ``G.list_run`` (loss logged as ``loss_CTC``) on a
    bucketed batch read from ``args.dirs.train.tfdata``; odd global steps run
    ``G.list_run_EODM`` (loss logged as ``loss_EODM``) on a bucketed batch read
    from ``args.dirs.untrain.tfdata``.  The loop stops once the labeled data
    has been consumed ``args.num_epochs`` times.  Every ``args.save_step``
    steps a checkpoint is written, every ``args.dev_step`` steps ``dev`` is run
    with the inference model, and every ``args.decode_step`` steps
    ``decode_test`` is run on ``args.dataset_test[10]``.

    All configuration is read from, and partly written back to, the
    module-level ``args`` object.  Nothing is returned.
    """
    # One entry of args.gpus is held back: the training model gets devices
    # 0..num_gpus-1 and the inference model is later placed on /gpu:num_gpus.
    args.num_gpus = len(args.gpus.split(',')) - 1
    args.list_gpus = ['/gpu:{}'.format(i) for i in range(args.num_gpus)]

    # bucket
    if args.bucket_boundaries:
        args.list_bucket_boundaries = [int(i) for i in args.bucket_boundaries.split(',')]

    assert args.num_batch_tokens
    # Per-bucket batch sizes aim at ~num_batch_tokens per device.  The extra
    # trailing entry presumably serves inputs longer than the last boundary
    # (NOTE(review): confirm against TFReader.fentch_batch_bucket).
    args.list_batch_size = (
        [int(args.num_batch_tokens / boundary) * args.num_gpus
         for boundary in args.list_bucket_boundaries]
        + [args.num_gpus])
    args.list_infer_batch_size = (
        [int(args.num_batch_tokens / boundary)
         for boundary in args.list_bucket_boundaries]
        + [1])
    args.batch_size *= args.num_gpus
    logging.info('\nbucket_boundaries: {} \nbatch_size: {}'.format(
        args.list_bucket_boundaries, args.list_batch_size))

    print('reading data form ', args.dirs.train.tfdata)
    dataReader_train = TFReader(args.dirs.train.tfdata, args=args)
    batch_train = dataReader_train.fentch_batch_bucket()
    dataReader_untrain = TFReader(args.dirs.untrain.tfdata, args=args)
    batch_untrain = dataReader_untrain.fentch_batch_bucket()
    args.dirs.untrain.tfdata = Path(args.dirs.untrain.tfdata)
    args.data.untrain_size = TFData.read_tfdata_info(args.dirs.untrain.tfdata)['size_dataset']

    feat, label = readTFRecord(args.dirs.dev.tfdata, args,
                               _shuffle=False, transform=True)
    dataloader_dev = ASRDataLoader(args.dataset_dev, args, feat, label,
                                   batch_size=args.batch_size, num_loops=1)

    tensor_global_step = tf.train.get_or_create_global_step()

    # get dataset ngram
    ngram_py, total_num = read_ngram(args.EODM.top_k, args.dirs.text.ngram, args.token2idx, type='list')
    kernel, py = ngram2kernel(ngram_py, args.EODM.ngram, args.EODM.top_k, args.dim_output)

    G = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        kernel=kernel,
        py=py,
        batch=batch_train,
        unbatch=batch_untrain,
        training=True,
        args=args)

    # The inference model is built on the device reserved above.
    args.list_gpus = ['/gpu:{}'.format(args.num_gpus)]

    G_infer = args.Model(
        tensor_global_step,
        encoder=args.model.encoder.type,
        decoder=args.model.decoder.type,
        training=False,
        args=args)
    vars_G = G.variables()

    size_variables()

    start_time = datetime.now()
    saver_G = tf.train.Saver(vars_G, max_to_keep=1)  # used only to warm-start G
    saver = tf.train.Saver(max_to_keep=15)
    summary = Summary(str(args.dir_log))

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    with tf.train.MonitoredTrainingSession(config=config) as sess:
        dataloader_dev.sess = sess
        if args.dirs.checkpoint_G:
            saver_G.restore(sess, args.dirs.checkpoint_G)

        batch_time = time()
        num_processed = 0
        num_processed_unbatch = 0
        progress = 0
        progress_unbatch = 0
        # Last observed loss / batch shape of each objective; kept across steps
        # so the periodic log line can report both.
        loss_CTC = 0.0
        shape_batch = [0, 0, 0]
        loss_EODM = 0.0
        shape_unbatch = [0, 0, 0]
        while progress < args.num_epochs:

            global_step, lr = sess.run([tensor_global_step, G.learning_rate])

            # Alternate the two objectives.  Only the op that actually ran on
            # this step advances its own sample counter; counting the other
            # one's stale shape would over-report progress and end training
            # early.  shape[0] is presumably the batch size — confirm.
            if global_step % 2 == 0:
                loss_CTC, shape_batch, _ = sess.run(G.list_run)
                num_processed += shape_batch[0]
            else:
                loss_EODM, shape_unbatch, _ = sess.run(G.list_run_EODM)
                num_processed_unbatch += shape_unbatch[0]

            used_time = time() - batch_time
            batch_time = time()
            progress = num_processed / args.data.train_size
            progress_unbatch = num_processed_unbatch / args.data.untrain_size

            if global_step % 50 == 0:
                logging.info('loss: {:.2f}|{:.2f}\tbatch: {}|{} lr:{:.6f} time:{:.2f}s {:.2f}% {:.2f}% step: {}'.format(
                              loss_CTC, loss_EODM, shape_batch, shape_unbatch, lr, used_time, progress*100.0, progress_unbatch*100.0, global_step))

            if global_step % args.save_step == args.save_step - 1:
                # Save through the raw session behind the monitored wrapper.
                saver.save(get_session(sess), str(args.dir_checkpoint/'model'), global_step=global_step, write_meta_graph=True)

            if global_step % args.dev_step == args.dev_step - 1:
                cer, wer = dev(
                    step=global_step,
                    dataloader=dataloader_dev,
                    model=G_infer,
                    sess=sess,
                    unit=args.data.unit,
                    idx2token=args.idx2token,
                    token2idx=args.token2idx)

            if global_step % args.decode_step == args.decode_step - 1:
                decode_test(
                    step=global_step,
                    sample=args.dataset_test[10],
                    model=G_infer,
                    sess=sess,
                    unit=args.data.unit,
                    idx2token=args.idx2token,
                    token2idx=args.token2idx)

    logging.info('training duration: {:.2f}h'.format((datetime.now()-start_time).total_seconds()/3600))