Example #1
def main(argv):
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    run_config = tf.estimator.RunConfig().replace(
        session_config=session_config)

    train_feeder = DataFeeder(['train_random_video.txt'] + [
        'train_random_video_04{}.txt'.format(n)
        for n in [13, 15, 16, 17, 18, 19, 20, 21]
    ])
    test_feeder = DataFeeder(['train_HRS_video.txt'], is_training=False)

    classifier = tf.estimator.Estimator(model_fn=my_model,
                                        model_dir='log8_moredata_cos',
                                        config=run_config,
                                        params=train_feeder)
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda: [train_feeder.tensor_data_generator(128, 100), None])
    test_spec = tf.estimator.EvalSpec(
        input_fn=lambda: [test_feeder.tensor_data_generator(128, 100), None],
        throttle_secs=600)
    tf.estimator.train_and_evaluate(classifier, train_spec, test_spec)
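This Estimator is built from a model_fn named my_model that is not shown in the snippet. As a minimal sketch of the tf.estimator model_fn contract it relies on (the layers, loss, and optimizer below are placeholders, not the original model):

import tensorflow as tf

def my_model(features, labels, mode, params):
    # params receives the DataFeeder passed to the Estimator above; unused in this sketch.
    net = tf.layers.dense(features, 256, activation=tf.nn.relu)
    logits = tf.layers.dense(net, 2)
    predictions = {'classes': tf.argmax(logits, axis=-1), 'logits': logits}
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)
    train_op = tf.train.AdamOptimizer(1e-4).minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)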
Example #2
def train(log_dir, args):
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = './jenie_Processed/amused/'
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs, feeder.input_lengths,
                         feeder.mel_targets, feeder.linear_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Train!
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            restore_path = './Trained Weights/model.ckpt-lj'
            print(restore_path)
            saver.restore(sess, restore_path)
            log('[INFO] Resuming from checkpoint: %s ' % (restore_path))

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message)

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step))
                    raise Exception('Loss Exploded')

                if step % args.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    input_seq, spectrogram, alignment = sess.run([
                        model.inputs[0], model.linear_outputs[0],
                        model.alignments[0]
                    ])
                    waveform = audio.inv_spectrogram(spectrogram.T)
                    audio.save_wav(
                        waveform,
                        os.path.join(log_dir, 'step-%d-audio.wav' % step))
                    plot.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s, %s, step=%d, loss=%.5f' %
                        (args.model, time_string(), step, loss))
                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            log('Exiting due to exception: %s' % e)
            traceback.print_exc()
            coord.request_stop(e)
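This training loop (and several of the examples below) tracks running statistics through a ValueWindow helper: values are appended each step and the running mean is read from .average. A minimal sketch consistent with that usage, assuming a simple fixed-size window (the real helper in these projects may differ in detail):

class ValueWindow:
    """Keeps the most recent `window_size` values and exposes their running average."""

    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def append(self, x):
        # Drop the oldest entry once the window is full, then add the new value.
        self._values = self._values[-(self._window_size - 1):] + [x]

    @property
    def average(self):
        return sum(self._values) / max(1, len(self._values))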
Example #3
def train(log_dir, input_path, checkpoint_path, is_restore):
    # Log the info
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model('tacotron', hparams)
        model.initialize(feeder.inputs, feeder.input_lengths,
                         feeder.mel_targets, feeder.linear_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # Bookkeeping:
    step = 0
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)
    # Train!
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            if is_restore:
                # Restore from a checkpoint if the user requested it.
                restore_path = checkpoint_path
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint')
            else:
                log('Starting new training')

            feeder.start_in_session(sess)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_interval = time.time() - start_time

                message = 'Step %d, %.03f sec, loss=%.05f' % (
                    step, time_interval, loss)
                log(message)

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step),
                        slack=True)
                    raise Exception('Loss Exploded')

                if step % hparams.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % hparams.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    input_seq, spectrogram, alignment = sess.run([
                        model.inputs[0], model.linear_outputs[0],
                        model.alignments[0]
                    ])
                    waveform = audio.inv_spectrogram(spectrogram.T)
                    audio.save_wav(
                        waveform,
                        os.path.join(log_dir, 'step-%d-audio.wav' % step))
                    plot.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s, %s, step=%d, loss=%.5f' %
                        ('tacotron', time_string(), step, loss))
                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            coord.request_stop(e)
Example #4
def train(logdir,args):

    # TODO:parse  ckpt,arguments,hparams
    checkpoint_path = os.path.join(logdir,'model.ckpt')
    input_path = args.data_dir
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)

    # TODO:set up datafeeder
    with tf.variable_scope('datafeeder') as scope:
        hp.data_length = None
        hp.initial_learning_rate = 0.0001
        hp.batch_size = 256
        hp.prime = True
        hp.stcmd = True
        feeder = DataFeeder(args=hp)
        log('num_sentences:'+str(len(feeder.wav_lst))) # 283600
        hp.input_vocab_size = len(feeder.pny_vocab)
        hp.final_output_dim = len(feeder.pny_vocab)
        hp.steps_per_epoch = len(feeder.wav_lst)//hp.batch_size
        log('steps_per_epoch:' + str(hp.steps_per_epoch))  # 17725
        log('pinyin_vocab_size:'+str(hp.input_vocab_size)) # 1292
        hp.label_vocab_size = len(feeder.han_vocab)
        log('label_vocab_size :' + str(hp.label_vocab_size)) # 6291

    # TODO:set up model
    global_step = tf.Variable(initial_value=0,name='global_step',trainable=False)
    valid_step = 0
    # valid_global_step = tf.Variable(initial_value=0,name='valid_global_step',trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model,hp)
        model.build_graph()
        model.add_loss()
        model.add_optimizer(global_step=global_step,loss=model.mean_loss)
        # TODO: summary
        stats = add_stats(model=model)
        valid_stats = add_dev_stats(model)


    # TODO:Set up saver and Bookkeeping
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    acc_window = ValueWindow(100)
    valid_time_window = ValueWindow(100)
    valid_loss_window = ValueWindow(100)
    valid_acc_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=20)
    first_serving = True
    # TODO: train
    with tf.Session() as sess:

        log(hparams_debug_string(hp))
        try:
            # TODO: Set writer and initializer
            summary_writer = tf.summary.FileWriter(logdir + '/train', sess.graph)
            summary_writer_dev = tf.summary.FileWriter(logdir + '/dev')
            sess.run(tf.global_variables_initializer())

            # TODO: Restore
            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s ' % restore_path)
            else:
                log('Starting new training run ')

            step = 0
            # TODO: epochs steps batch
            for i in range(args.epochs):
                batch_data = feeder.get_lm_batch()
                log('Training epoch ' + str(i) + ':')
                for j in range(hp.steps_per_epoch):
                    input_batch, label_batch = next(batch_data)
                    feed_dict = {
                        model.x:input_batch,
                        model.y:label_batch,
                    }
                    # TODO: Run one step ~~~
                    start_time = time.time()
                    total_step,batch_loss,batch_acc,opt = sess.run([global_step, model.mean_loss,model.acc,model.optimize],feed_dict=feed_dict)
                    time_window.append(time.time() - start_time)
                    step = total_step

                    # TODO: Append loss
                    loss_window.append(batch_loss)
                    acc_window.append(batch_acc)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f,acc=%.05f, avg_acc=%.05f,  lr=%.07f]' % (
                        step, time_window.average, batch_loss, loss_window.average,batch_acc,acc_window.average,K.get_value(model.learning_rate))
                    log(message)

                    # TODO: Check loss
                    if math.isnan(batch_loss):
                        log('Loss exploded to %.05f at step %d!' % (batch_loss, step))
                        raise Exception('Loss Exploded')

                    # TODO: Check summary
                    if step % args.summary_interval == 0:
                        log('Writing summary at step: %d' % step)
                        summary_writer.add_summary(sess.run(stats,feed_dict=feed_dict), step)

                    # TODO: Check checkpoint
                    if step % args.checkpoint_interval == 0:
                        log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                        saver.save(sess, checkpoint_path, global_step=step)
                        log('test acc...')

                        label,final_pred_label = sess.run([
                            model.y, model.preds],feed_dict=feed_dict)



                        log('label.shape           :'+str(label.shape)) # (batch_size , label_length)
                        log('final_pred_label.shape:'+str(np.asarray(final_pred_label).shape)) # (1, batch_size, decode_length<=label_length)

                        log('label           : '+str(label[0]))
                        log('final_pred_label: '+str( np.asarray(final_pred_label)[0]))


                    # TODO: serving
                    if args.serving:  # and total_step // hp.steps_per_epoch > 5:
                        np.save('logdir/lm_pinyin_dict.npy', feeder.pny_vocab)
                        np.save('logdir/lm_hanzi_dict.npy', feeder.han_vocab)
                        print(total_step)
                        # TODO: Set up serving builder and signature map
                        serve_dir = args.serving_dir + '0001'
                        if os.path.exists(serve_dir):
                            os.removedirs(serve_dir)  # NOTE: removedirs only deletes empty dirs; a non-empty export needs shutil.rmtree
                        builder = tf.saved_model.builder.SavedModelBuilder(export_dir=serve_dir)
                        input = tf.saved_model.utils.build_tensor_info(model.x)
                        output_labels = tf.saved_model.utils.build_tensor_info(model.preds)

                        prediction_signature = (
                            tf.saved_model.signature_def_utils.build_signature_def(
                                inputs={'pinyin': input},
                                outputs={'hanzi': output_labels},
                                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
                        )
                        if first_serving:
                            first_serving = False
                            builder.add_meta_graph_and_variables(
                                sess=sess, tags=[tf.saved_model.tag_constants.SERVING],
                                signature_def_map={
                                    'predict_Pinyin2Hanzi':
                                        prediction_signature,
                                },
                                main_op=tf.tables_initializer(),
                                strip_default_attrs=True
                            )

                        builder.save()
                        log('Done store serving-model')
                        raise Exception('Done store serving-model')  # deliberately stop the run once the serving model is exported

                    # TODO: Validation
                    # if total_step % hp.steps_per_epoch == 0 and  i >= 10:
                    if total_step % hp.steps_per_epoch == 0:
                        log('validation...')
                        valid_start = time.time()
                        # TODO: validation
                        valid_hp = copy.deepcopy(hp)
                        print('feature_type: ',hp.feature_type)
                        valid_hp.data_type = 'dev'
                        valid_hp.thchs30 = True
                        valid_hp.aishell = True
                        valid_hp.prime = True
                        valid_hp.stcmd = True
                        valid_hp.shuffle = True
                        valid_hp.data_length = None

                        valid_feeder = DataFeeder(args=valid_hp)
                        valid_feeder.pny_vocab = feeder.pny_vocab
                        valid_feeder.han_vocab = feeder.han_vocab
                        # valid_feeder.am_vocab = feeder.am_vocab
                        valid_batch_data = valid_feeder.get_lm_batch()
                        log('valid_num_sentences:' + str(len(valid_feeder.wav_lst))) # 15219
                        valid_hp.input_vocab_size = len(valid_feeder.pny_vocab)
                        valid_hp.final_output_dim = len(valid_feeder.pny_vocab)
                        valid_hp.steps_per_epoch = len(valid_feeder.wav_lst) // valid_hp.batch_size
                        log('valid_steps_per_epoch:' + str(valid_hp.steps_per_epoch)) # 951
                        log('valid_pinyin_vocab_size:' + str(valid_hp.input_vocab_size)) # 1124
                        valid_hp.label_vocab_size = len(valid_feeder.han_vocab)
                        log('valid_label_vocab_size :' + str(valid_hp.label_vocab_size)) # 3327

                        # The dev set only needs a single epoch
                        with tf.variable_scope('validation') as scope:
                            for k in range(len(valid_feeder.wav_lst) // valid_hp.batch_size):
                                valid_input_batch,valid_label_batch = next(valid_batch_data)
                                valid_feed_dict = {
                                    model.x: valid_input_batch,
                                    model.y: valid_label_batch,
                                }
                                # TODO: Run one step
                                valid_start_time = time.time()
                                valid_batch_loss,valid_batch_acc = sess.run([model.mean_loss,model.acc], feed_dict=valid_feed_dict)
                                valid_time_window.append(time.time() - valid_start_time)
                                valid_loss_window.append(valid_batch_loss)
                                valid_acc_window.append(valid_batch_acc)
                                # print('loss',loss,'batch_loss',batch_loss)
                                message = 'Valid-Step %-7d [%.03f sec/step, valid_loss=%.05f, avg_loss=%.05f, valid_acc=%.05f, avg_acc=%.05f]' % (
                                    valid_step, valid_time_window.average, valid_batch_loss, valid_loss_window.average,valid_batch_acc,valid_acc_window.average)
                                log(message)
                                summary_writer_dev.add_summary(sess.run(valid_stats,feed_dict=valid_feed_dict), valid_step)
                                valid_step += 1
                            log('Done validation! Total time cost (sec): ' + str(time.time() - valid_start))

        except Exception as e:
            log('Exiting due to exception: %s' % e)
            traceback.print_exc()
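Example #4 exports the language model under the signature name 'predict_Pinyin2Hanzi'. As a rough sketch of loading that export back for inference in a fresh session (the export path and the input array below are placeholders, not part of the original code):

import numpy as np
import tensorflow as tf

export_dir = 'serving_dir/0001'  # placeholder for whatever args.serving_dir + '0001' resolved to
pinyin_ids = np.zeros((1, 50), dtype=np.int32)  # dummy batch shaped like model.x

with tf.Session(graph=tf.Graph()) as sess:
    meta_graph = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    sig = meta_graph.signature_def['predict_Pinyin2Hanzi']
    hanzi = sess.run(sig.outputs['hanzi'].name,
                     feed_dict={sig.inputs['pinyin'].name: pinyin_ids})
    print(hanzi)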
Example #5
def generate_feed_dict(config,tower_data):
    train_mod = math.ceil(len(tower_data['input_ids']) / config['model']['gpu_num'])
    for k in range(config['model']['gpu_num']):
        start = k * train_mod
        end = start + train_mod
        for key in tower_data:
            print(np.array(tower_data[key][start:end]).shape)

config ={'data':{'batch_size':10,
                      'train_dataset_filePath':'/datastore/liu121/bert_trail/train_data/train_data_%d.pkl',
                      'train_file_num':5
                      },
         'model':{'gpu_num':2}}

for i in range(config['data']['train_file_num']):
    df = DataFeeder(config,i)
    dataset = df.dataset_generator()
    print('total dataset size: ',np.array(df.train_input_ids).shape[0])
    exit()
    for input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels in dataset:
        print('batch size: ',np.array(input_ids).shape)
        tower_data = {'input_ids': input_ids,
                      'input_mask': input_mask,
                      'segment_ids': segment_ids,
                      'masked_lm_positions': masked_lm_positions,
                      'masked_lm_ids': masked_lm_ids,
                      'masked_lm_weights': masked_lm_weights,
                      'next_sentence_labels': next_sentence_labels}
        # generate_feed_dict(config,tower_data)
        print('=====================================')
    exit()
Example #6
infolog.set_tf_log(load_path)  # Estimator --> save logs here
tf.logging.set_verbosity(tf.logging.INFO)  # required so the training logs are printed

# load data
inputs, targets, word_to_index, index_to_word, VOCAB_SIZE, INPUT_LENGTH, OUTPUT_LENGTH = load_data(
    hp)  # (50000, 29), (50000, 12)
hp.add_hparam('VOCAB_SIZE', VOCAB_SIZE)
hp.add_hparam('INPUT_LENGTH', INPUT_LENGTH)  # 29
hp.add_hparam('OUTPUT_LENGTH', OUTPUT_LENGTH)  # 11

train_input, test_input, train_target, test_target = train_test_split(
    inputs, targets, test_size=0.1, random_state=13371447)

datfeeder = DataFeeder(train_input,
                       train_target,
                       test_input,
                       test_target,
                       batch_size=hp.BATCH_SIZE,
                       num_epoch=hp.NUM_EPOCHS)


def seq_accuracy(
    labels, predictions
):  # receives the values handed over from tf.estimator.EstimatorSpec, plus features and labels, as arguments
    return {
        'seq_accuracy':
        tf.metrics.mean(
            tf.reduce_prod(tf.cast(
                tf.equal(predictions['predition'], labels[:, 1:]), tf.float16),
                           axis=-1))
    }
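seq_accuracy follows the Estimator metric_fn contract: it takes labels and the predictions dict and returns a name-to-metric mapping. A hedged sketch of wiring it in with tf.estimator.add_metrics, assuming an Estimator named estimator and an eval input_fn defined elsewhere in this project:

import tensorflow as tf

# `estimator` and `eval_input_fn` are assumed to be defined elsewhere.
estimator_with_metric = tf.estimator.add_metrics(estimator, seq_accuracy)
eval_results = estimator_with_metric.evaluate(input_fn=eval_input_fn)
print(eval_results['seq_accuracy'])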
Example #7
config = tf.ConfigProto()
# occupy gpu gracefully
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    # load graph
    new_saver = tf.train.import_meta_graph('log3/model.ckpt-1010600.meta')
    # load values
    new_saver.restore(sess, tf.train.latest_checkpoint('log3/'))

    # nodes
    graph = tf.get_default_graph()
    content_tensor = graph.get_tensor_by_name('IteratorGetNext:2')
    dan_tensor = graph.get_tensor_by_name('fc/fc_dan/BiasAdd:0')
    cdssm_tensor = graph.get_tensor_by_name('fc/fc_cdssm/BiasAdd:0')

    feeder = DataFeeder('train_hrs_video.txt')
    with open('pred_v3.txt', 'w') as f:
        for i, (key, dan, cdssm, content) in enumerate(feeder.read_feature()):
            dan = np.array(dan).astype(int)
            cdssm = np.array(cdssm).astype(int)
            content = np.expand_dims(content, 0)
            dan_v, cdssm_v = sess.run([dan_tensor, cdssm_tensor],
                                      feed_dict={content_tensor: content})
            dan_v = np.round(np.minimum(255, np.maximum(0,
                                                        dan_v[0]))).astype(int)
            cdssm_v = np.round(np.minimum(255,
                                          np.maximum(0,
                                                     cdssm_v[0]))).astype(int)
            f.write(key + '\t' + '\t'.join(str(x) for x in dan_v) + '\t' +
                    '\t'.join(str(x) for x in cdssm_v) + '\n')
            if i % 10000 == 0:
                print('processed %d examples' % i)  # the original snippet is truncated here; a progress print is assumed
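The tensor names above ('IteratorGetNext:2', 'fc/fc_dan/BiasAdd:0') have to be discovered from the restored graph. One way to list candidate names, sketched against the same checkpoint paths used in this example:

import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('log3/model.ckpt-1010600.meta')
    saver.restore(sess, tf.train.latest_checkpoint('log3/'))
    graph = tf.get_default_graph()
    # Every op's outputs are tensors named '<op_name>:<output_index>'.
    for op in graph.get_operations():
        if 'fc' in op.name or op.type == 'IteratorGetNext':
            print(op.name, [t.shape for t in op.outputs])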
Example #8
def train(logdir,args):

    # TODO:parse  ckpt,arguments,hparams
    checkpoint_path = os.path.join(logdir,'model.ckpt')
    input_path = args.data_dir
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.model)

    # TODO:set up datafeeder
    with tf.variable_scope('datafeeder') as scope:
        hp.aishell=True
        hp.prime=False
        hp.stcmd=False
        hp.data_path = 'D:/pycharm_proj/corpus_zn/'
        hp.initial_learning_rate = 0.001
        hp.decay_learning_rate=False
        hp.data_length = 512
        hp.batch_size = 64
        feeder = DataFeeder(args=hp)
        log('num_wavs:'+str(len(feeder.wav_lst))) # 283600
        hp.input_vocab_size = len(feeder.pny_vocab)
        hp.final_output_dim = len(feeder.pny_vocab)
        hp.steps_per_epoch = len(feeder.wav_lst)//hp.batch_size
        log('steps_per_epoch:' + str(hp.steps_per_epoch))  # 17725
        log('pinyin_vocab_size:'+str(hp.input_vocab_size)) # 1292
        hp.label_vocab_size = len(feeder.han_vocab)
        log('label_vocab_size :' + str(hp.label_vocab_size)) # 6291

    # TODO:set up model
    global_step = tf.Variable(initial_value=0,name='global_step',trainable=False)
    valid_step = 0
    # valid_global_step = tf.Variable(initial_value=0,name='valid_global_step',trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model,hp)
        model.build_graph()
        model.add_loss()
        model.add_decoder()
        model.add_optimizer(global_step=global_step)
        # TODO: summary
        stats = add_stats(model=model)
        valid_stats = add_dev_stats(model)


    # TODO:Set up saver and Bookkeeping
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    wer_window = ValueWindow(100)
    valid_time_window = ValueWindow(100)
    valid_loss_window = ValueWindow(100)
    valid_wer_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=20)
    first_serving = True
    # TODO: train
    with tf.Session() as sess:
        try:
            # TODO: Set writer and initializer
            summary_writer = tf.summary.FileWriter(logdir, sess.graph)
            sess.run(tf.global_variables_initializer())

            # TODO: Restore
            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s ' % restore_path)
            else:
                log('Starting new training run ')

            step = 0
            # TODO: epochs steps batch
            for i in range(args.epochs):
                batch_data = feeder.get_am_batch()
                log('Training epoch ' + str(i) + ':')
                for j in range(hp.steps_per_epoch):
                    input_batch = next(batch_data)
                    feed_dict = {model.inputs:input_batch['the_inputs'],
                                 model.labels:input_batch['the_labels'],
                                 model.input_lengths:input_batch['input_length'],
                                 model.label_lengths:input_batch['label_length']}
                    # TODO: Run one step
                    start_time = time.time()
                    total_step, array_loss, batch_loss,opt = sess.run([global_step, model.ctc_loss,
                                                            model.batch_loss,model.optimize],feed_dict=feed_dict)
                    time_window.append(time.time() - start_time)
                    step = total_step

                    # TODO: Append loss
                    # loss = np.sum(array_loss).item()/hp.batch_size
                    loss_window.append(batch_loss)

                    # print('loss',loss,'batch_loss',batch_loss)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, lr=%.07f]' % (
                        step, time_window.average, batch_loss, loss_window.average,K.get_value(model.learning_rate))
                    log(message)
                    # ctc_loss returns an array shaped [batch_size, 1]; without summing it first this raised
                    # "only size-1 arrays can be converted to Python scalars"

                    # TODO: Check loss
                    if math.isnan(batch_loss):
                        log('Loss exploded to %.05f at step %d!' % (batch_loss, step))
                        raise Exception('Loss Exploded')

                    # TODO: Check summary
                    if step % args.summary_interval == 0:
                        log('Writing summary at step: %d' % step)
                        summary_writer.add_summary(sess.run(stats,feed_dict=feed_dict), step)

                    # TODO: Check checkpoint
                    if step % args.checkpoint_interval == 0:
                        log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                        saver.save(sess, checkpoint_path, global_step=step)
                        log('test acc...')
                        # eval_start_time = time.time()
                        # with tf.name_scope('eval') as scope:
                        #     with open(os.path.expanduser('~/my_asr2/datasets/resource/preprocessedData/dev-meta.txt'), encoding='utf-8') as f:
                        #         metadata = [line.strip().split('|') for line in f]
                        #         random.shuffle(metadata)
                        #     eval_loss = []
                        #     batch_size = args.hp.batch_size
                        #     batchs = len(metadata)//batch_size
                        #     for i in range(batchs):
                        #         batch = metadata[i*batch_size : i*batch_size+batch_size]
                        #         batch = list(map(eval_get_example,batch))
                        #         batch = eval_prepare_batch(batch)
                        #         feed_dict = {'labels':batch[0],}
                        #     label,final_pred_label ,log_probabilities = sess.run([
                        #         model.labels[0], model.decoded[0], model.log_probabilities[0]])
                        #     # Without [] around the fetches this raised an error: https://github.com/tensorflow/tensorflow/issues/11840
                        #     print('label:            ' ,label)
                        #     print('final_pred_label: ', final_pred_label[0])
                        #     log('eval time: %.03f, avg_eval_loss: %.05f' % (time.time()-eval_start_time,np.mean(eval_loss)))

                        label,final_pred_label ,log_probabilities,y_pred2 = sess.run([
                            model.labels, model.decoded, model.log_probabilities,model.y_pred2],feed_dict=feed_dict)
                        # Without [] around the fetches this raised an error: https://github.com/tensorflow/tensorflow/issues/11840
                        print('label.shape           :',label.shape) # (batch_size , label_length)
                        print('final_pred_label.shape:',np.asarray(final_pred_label).shape) # (1, batch_size, decode_length<=label_length)
                        print('y_pred2.shape         : ', y_pred2.shape)
                        print('label:            ' ,label[0])
                        print('y_pred2         : ', y_pred2[0])
                        print('final_pred_label: ', np.asarray(final_pred_label)[0][0])

                        #  These could not be printed at first because tf.nn.ctc_beam_search_decoder returns a SparseTensor;
                        #  the Keras-wrapped decoder converts it to a dense tensor, which prints fine.

                        # waveform = audio.inv_spectrogram(spectrogram.T)
                        # audio.save_wav(waveform, os.path.join(logdir, 'step-%d-audio.wav' % step))
                        # plot.plot_alignment(alignment, os.path.join(logdir, 'step-%d-align.png' % step),
                        #                     info='%s, %s, %s, step=%d, loss=%.5f' % (
                        #                     args.model, commit, time_string(), step, loss))
                        # log('Input: %s' % sequence_to_text(input_seq))

                    # TODO: Check stop
                    if step % hp.steps_per_epoch ==0:
                        # TODO: Set up serving builder and signature map
                        serve_dir = args.serving_dir + '_' + str(total_step//hp.steps_per_epoch -1)
                        if os.path.exists(serve_dir):
                            os.removedirs(serve_dir)  # NOTE: removedirs only deletes empty dirs; a non-empty export needs shutil.rmtree
                        builder = tf.saved_model.builder.SavedModelBuilder(export_dir=serve_dir)
                        input_spec = tf.saved_model.utils.build_tensor_info(model.inputs)
                        input_len = tf.saved_model.utils.build_tensor_info(model.input_lengths)
                        output_labels = tf.saved_model.utils.build_tensor_info(model.decoded2)
                        output_logits = tf.saved_model.utils.build_tensor_info(model.pred_logits)
                        prediction_signature = (
                            tf.saved_model.signature_def_utils.build_signature_def(
                                inputs={'spec': input_spec, 'len': input_len},
                                outputs={'label': output_labels, 'logits': output_logits},
                                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
                        )
                        if first_serving:
                            first_serving = False
                            builder.add_meta_graph_and_variables(
                                sess=sess, tags=[tf.saved_model.tag_constants.SERVING, 'ASR'],
                                signature_def_map={
                                    'predict_AudioSpec2Pinyin':
                                        prediction_signature,
                                },
                                main_op=tf.tables_initializer(),
                                strip_default_attrs=True
                            )
                        else:
                            builder.add_meta_graph_and_variables(
                                sess=sess, tags=[tf.saved_model.tag_constants.SERVING, 'ASR'],
                                signature_def_map={
                                    'predict_AudioSpec2Pinyin':
                                        prediction_signature,
                                },
                                strip_default_attrs=True
                            )

                        builder.save()
                        log('Done store serving-model')

                    # TODO: Validation
                    if step % hp.steps_per_epoch ==0 and i >= 10:
                        log('validation...')
                        valid_start = time.time()
                        # TODO: validation
                        valid_hp = copy.deepcopy(hp)  # deep copy (as in Example #4) so validation settings do not clobber the training hparams (requires import copy)
                        valid_hp.data_type = 'dev'
                        valid_hp.thchs30 = True
                        valid_hp.aishell = False
                        valid_hp.prime = False
                        valid_hp.stcmd = False
                        valid_hp.shuffle = True
                        valid_hp.data_length = None

                        valid_feeder = DataFeeder(args=valid_hp)
                        valid_batch_data = valid_feeder.get_am_batch()
                        log('valid_num_wavs:' + str(len(valid_feeder.wav_lst))) # 15219
                        valid_hp.input_vocab_size = len(valid_feeder.pny_vocab)
                        valid_hp.final_output_dim = len(valid_feeder.pny_vocab)
                        valid_hp.steps_per_epoch = len(valid_feeder.wav_lst) // valid_hp.batch_size
                        log('valid_steps_per_epoch:' + str(valid_hp.steps_per_epoch)) # 951
                        log('valid_pinyin_vocab_size:' + str(valid_hp.input_vocab_size)) # 1124
                        valid_hp.label_vocab_size = len(valid_feeder.han_vocab)
                        log('valid_label_vocab_size :' + str(valid_hp.label_vocab_size)) # 3327
                        words_num = 0
                        word_error_num = 0

                        # The dev set only needs a single epoch
                        with tf.variable_scope('validation') as scope:
                            for k in range(len(valid_feeder.wav_lst) // valid_hp.batch_size):
                                valid_input_batch = next(valid_batch_data)
                                valid_feed_dict = {model.inputs: valid_input_batch['the_inputs'],
                                             model.labels: valid_input_batch['the_labels'],
                                             model.input_lengths: valid_input_batch['input_length'],
                                             model.label_lengths: valid_input_batch['label_length']}
                                # TODO: Run one step
                                valid_start_time = time.time()
                                valid_batch_loss,valid_WER = sess.run([model.batch_loss,model.WER], feed_dict=valid_feed_dict)
                                valid_time_window.append(time.time() - valid_start_time)
                                valid_loss_window.append(valid_batch_loss)
                                valid_wer_window.append(valid_WER)
                                # print('loss',loss,'batch_loss',batch_loss)
                                message = 'Valid-Step %-7d [%.03f sec/step, valid_loss=%.05f, avg_loss=%.05f, WER=%.05f, avg_WER=%.05f, lr=%.07f]' % (
                                    valid_step, valid_time_window.average, valid_batch_loss, valid_loss_window.average,valid_WER,valid_wer_window.average,K.get_value(model.learning_rate))
                                log(message)
                                summary_writer.add_summary(sess.run(valid_stats,feed_dict=valid_feed_dict), valid_step)
                                valid_step += 1
                            log('Done validation! Total time cost (sec): ' + str(time.time() - valid_start))

        except Exception as e:
            log('Exiting due to exception: %s' % e)
            traceback.print_exc()
Example #9
File: train.py Project: mutiann/ccc
def train(args):
    if args.model_path is None:
        msg = 'Prepare for new run ...'
        output_dir = os.path.join(
            args.log_dir, args.run_name + '_' +
            datetime.datetime.now().strftime('%m%d_%H%M'))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        ckpt_dir = os.path.join(
            args.ckpt_dir, args.run_name + '_' +
            datetime.datetime.now().strftime('%m%d_%H%M'))
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
    else:
        msg = 'Restart previous run ...\nlogs to save to %s, ckpt to save to %s, model to load from %s' % \
                     (args.log_dir, args.ckpt_dir, args.model_path)
        output_dir = args.log_dir
        ckpt_dir = args.ckpt_dir
        if not os.path.isdir(output_dir):
            print('Invalid log dir: %s' % output_dir)
            return
        if not os.path.isdir(ckpt_dir):
            print('Invalid ckpt dir: %s' % ckpt_dir)
            return

    set_logger(os.path.join(output_dir, 'outputs.log'))
    logging.info(msg)

    global device
    if args.device is not None:
        logging.info('Setting device to ' + args.device)
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logging.info('Setting up...')
    hparams.parse(args.hparams)
    logging.info(hparams_debug_string())

    model = EdgeClassification

    if hparams.use_roberta:
        logging.info('Using Roberta...')
        model = RobertaEdgeClassification

    global_step = 0

    if args.model_path is None:
        if hparams.load_pretrained:
            logging.info('Load online pretrained model...' + (
                ('cached at ' +
                 args.cache_path) if args.cache_path is not None else ''))
            if hparams.use_roberta:
                model = model.from_pretrained('roberta-base',
                                              cache_dir=args.cache_path,
                                              hparams=hparams)
            else:
                model = model.from_pretrained('bert-base-uncased',
                                              cache_dir=args.cache_path,
                                              hparams=hparams)
        else:
            logging.info('Build model from scratch...')
            if hparams.use_roberta:
                config = RobertaConfig.from_pretrained('roberta-base')
            else:
                config = BertConfig.from_pretrained('bert-base-uncased')
            model = model(config=config, hparams=hparams)
    else:
        if not os.path.isdir(args.model_path):
            raise OSError(str(args.model_path) + ' not found')
        logging.info('Load saved model from %s ...' % (args.model_path))
        model = model.from_pretrained(args.model_path, hparams=hparams)
        step = args.model_path.split('_')[-1]
        if step.isnumeric():
            global_step = int(step)
            logging.info('Initial step=%d' % global_step)

    if hparams.use_roberta:
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    else:
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    hparams.parse(args.hparams)
    logging.info(hparams_debug_string())

    if hparams.text_sample_eval:
        if args.eval_text_path is None:
            raise ValueError('eval_text_path not given')
        if ':' not in args.eval_text_path:
            eval_data_paths = [args.eval_text_path]
        else:
            eval_data_paths = args.eval_text_path.split(':')
        eval_feeder = []
        for p in eval_data_paths:
            name = os.path.split(p)[-1]
            if name.endswith('.tsv'):
                name = name[:-4]
            eval_feeder.append(
                (name, ExternalTextFeeder(p, hparams, tokenizer, 'dev')))
    else:
        eval_feeder = [('', DataFeeder(args.data_dir, hparams, tokenizer,
                                       'dev'))]

    tb_writer = SummaryWriter(output_dir)

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        hparams.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=hparams.learning_rate,
                      eps=hparams.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=hparams.warmup_steps,
        lr_decay_step=hparams.lr_decay_step,
        max_lr_decay_rate=hparams.max_lr_decay_rate)

    acc_step = global_step * hparams.gradient_accumulation_steps
    time_window = ValueWindow()
    loss_window = ValueWindow()
    acc_window = ValueWindow()
    model.to(device)
    model.zero_grad()
    tr_loss = tr_acc = 0.0
    start_time = time.time()

    if args.model_path is not None:
        logging.info('Load saved model from %s ...' % (args.model_path))
        if os.path.exists(os.path.join(args.model_path, 'optimizer.pt')) \
                and os.path.exists(os.path.join(args.model_path, 'scheduler.pt')):
            optimizer.load_state_dict(
                torch.load(os.path.join(args.model_path, 'optimizer.pt')))
            scheduler.load_state_dict(
                torch.load(os.path.join(args.model_path, 'scheduler.pt')))
        else:
            logging.warning('Could not find saved optimizer/scheduler')

    if global_step > 0:
        logs = run_eval(args, model, eval_feeder)
        for key, value in logs.items():
            tb_writer.add_scalar(key, value, global_step)

    logging.info('Start training...')
    if hparams.text_sample_train:
        train_feeder = PrebuiltTrainFeeder(args.train_text_path, hparams,
                                           tokenizer, 'train')
    else:
        train_feeder = DataFeeder(args.data_dir, hparams, tokenizer, 'train')

    while True:
        batch = train_feeder.next_batch()
        model.train()

        outputs = model(input_ids=batch.input_ids.to(device),
                        attention_mask=batch.input_mask.to(device),
                        token_type_ids=None if batch.token_type_ids is None
                        else batch.token_type_ids.to(device),
                        labels=batch.labels.to(device))
        loss = outputs['loss']
        preds = outputs['preds']

        acc = torch.mean((preds.cpu() == batch.labels).float())
        preds = preds.cpu().detach().numpy()
        labels = batch.labels.detach().numpy()
        t_acc = np.sum(np.logical_and(preds == 1, labels
                                      == 1)) / np.sum(labels == 1)
        f_acc = np.sum(np.logical_and(preds == 0, labels
                                      == 0)) / np.sum(labels == 0)

        if hparams.gradient_accumulation_steps > 1:
            loss = loss / hparams.gradient_accumulation_steps
            acc = acc / hparams.gradient_accumulation_steps

        tr_loss += loss.item()
        tr_acc += acc.item()
        loss.backward()
        acc_step += 1

        if acc_step % hparams.gradient_accumulation_steps != 0:
            continue

        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       hparams.max_grad_norm)
        optimizer.step()
        scheduler.step(None)
        model.zero_grad()
        global_step += 1

        step_time = time.time() - start_time
        time_window.append(step_time)
        loss_window.append(tr_loss)
        acc_window.append(tr_acc)

        if global_step % args.save_steps == 0:
            # Save model checkpoint
            model_to_save = model.module if hasattr(model, 'module') else model
            cur_ckpt_dir = os.path.join(ckpt_dir,
                                        'checkpoint_%d' % (global_step))
            if not os.path.exists(cur_ckpt_dir):
                os.makedirs(cur_ckpt_dir)
            model_to_save.save_pretrained(cur_ckpt_dir)
            torch.save(args, os.path.join(cur_ckpt_dir, 'training_args.bin'))
            torch.save(optimizer.state_dict(),
                       os.path.join(cur_ckpt_dir, 'optimizer.pt'))
            torch.save(scheduler.state_dict(),
                       os.path.join(cur_ckpt_dir, 'scheduler.pt'))
            logging.info("Saving model checkpoint to %s", cur_ckpt_dir)

        if global_step % args.logging_steps == 0:
            logs = run_eval(args, model, eval_feeder)

            learning_rate_scalar = scheduler.get_lr()[0]
            logs['learning_rate'] = learning_rate_scalar
            logs['loss'] = loss_window.average
            logs['acc'] = acc_window.average

            for key, value in logs.items():
                tb_writer.add_scalar(key, value, global_step)

        message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, acc=%.05f, avg_acc=%.05f, t_acc=%.05f, f_acc=%.05f]' % (
            global_step, step_time, tr_loss, loss_window.average, tr_acc,
            acc_window.average, t_acc, f_acc)
        logging.info(message)
        tr_loss = tr_acc = 0.0
        start_time = time.time()