Example #1
def train(logdir,args):

    # TODO: parse ckpt, arguments, hparams
    checkpoint_path = os.path.join(logdir,'model.ckpt')
    input_path = args.data_dir
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from : %s ' % input_path)
    log('Using model : %s' %args.model)

    # TODO: set up datafeeder
    with tf.variable_scope('datafeeder') as scope:
        hp.data_length = None
        hp.initial_learning_rate = 0.0001
        hp.batch_size = 256
        hp.prime = True
        hp.stcmd = True
        feeder = DataFeeder(args=hp)
        log('num_sentences:'+str(len(feeder.wav_lst))) # 283600
        hp.input_vocab_size = len(feeder.pny_vocab)
        hp.final_output_dim = len(feeder.pny_vocab)
        hp.steps_per_epoch = len(feeder.wav_lst)//hp.batch_size
        log('steps_per_epoch:' + str(hp.steps_per_epoch))  # 17725
        log('pinyin_vocab_size:'+str(hp.input_vocab_size)) # 1292
        hp.label_vocab_size = len(feeder.han_vocab)
        log('label_vocab_size :' + str(hp.label_vocab_size)) # 6291

    # TODO: set up model
    global_step = tf.Variable(initial_value=0,name='global_step',trainable=False)
    valid_step = 0
    # valid_global_step = tf.Variable(initial_value=0,name='valid_global_step',trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model,hp)
        model.build_graph()
        model.add_loss()
        model.add_optimizer(global_step=global_step,loss=model.mean_loss)
        # TODO: summary
        stats = add_stats(model=model)
        valid_stats = add_dev_stats(model)


    # TODO: Set up saver and bookkeeping
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    acc_window = ValueWindow(100)
    valid_time_window = ValueWindow(100)
    valid_loss_window = ValueWindow(100)
    valid_acc_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=20)
    first_serving = True
    # TODO: train
    with tf.Session() as sess:

        log(hparams_debug_string(hp))
        try:
            # TODO: Set writer and initializer
            summary_writer = tf.summary.FileWriter(logdir + '/train', sess.graph)
            summary_writer_dev = tf.summary.FileWriter(logdir + '/dev')
            sess.run(tf.global_variables_initializer())

            # TODO: Restore
            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s ' % restore_path)
            else:
                log('Starting new training run ')

            step = 0
            # TODO: epochs steps batch
            for i in range(args.epochs):
                batch_data = feeder.get_lm_batch()
                log('Training epoch ' + str(i) + ':')
                for j in range(hp.steps_per_epoch):
                    input_batch, label_batch = next(batch_data)
                    feed_dict = {
                        model.x:input_batch,
                        model.y:label_batch,
                    }
                    # TODO: Run one step
                    start_time = time.time()
                    total_step,batch_loss,batch_acc,opt = sess.run([global_step, model.mean_loss,model.acc,model.optimize],feed_dict=feed_dict)
                    time_window.append(time.time() - start_time)
                    step = total_step

                    # TODO: Append loss
                    loss_window.append(batch_loss)
                    acc_window.append(batch_acc)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f,acc=%.05f, avg_acc=%.05f,  lr=%.07f]' % (
                        step, time_window.average, batch_loss, loss_window.average,batch_acc,acc_window.average,K.get_value(model.learning_rate))
                    log(message)

                    # TODO: Check loss
                    if math.isnan(batch_loss):
                        log('Loss exploded to %.05f at step %d!' % (batch_loss, step))
                        raise Exception('Loss Exploded')

                    # TODO: Check summary
                    if step % args.summary_interval == 0:
                        log('Writing summary at step: %d' % step)
                        summary_writer.add_summary(sess.run(stats,feed_dict=feed_dict), step)

                    # TODO: Check checkpoint
                    if step % args.checkpoint_interval == 0:
                        log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                        saver.save(sess, checkpoint_path, global_step=step)
                        log('test acc...')

                        label,final_pred_label = sess.run([
                            model.y, model.preds],feed_dict=feed_dict)



                        log('label.shape           :'+str(label.shape)) # (batch_size , label_length)
                        log('final_pred_label.shape:'+str(np.asarray(final_pred_label).shape)) # (1, batch_size, decode_length<=label_length)

                        log('label           : '+str(label[0]))
                        log('final_pred_label: '+str( np.asarray(final_pred_label)[0]))


                    # TODO: serving
                    if args.serving:  # and total_step // hp.steps_per_epoch > 5:
                        np.save('logdir/lm_pinyin_dict.npy',feeder.pny_vocab)
                        np.save('logdir/lm_hanzi_dict.npy',feeder.han_vocab)
                        print('Exporting serving model at step', total_step)
                        # TODO: Set up serving builder and signature map
                        serve_dir = args.serving_dir + '0001'
                        if os.path.exists(serve_dir):
                            shutil.rmtree(serve_dir)  # os.removedirs would only remove empty directories
                        builder = tf.saved_model.builder.SavedModelBuilder(export_dir=serve_dir)
                        input = tf.saved_model.utils.build_tensor_info(model.x)
                        output_labels = tf.saved_model.utils.build_tensor_info(model.preds)

                        prediction_signature = (
                            tf.saved_model.signature_def_utils.build_signature_def(
                                inputs={'pinyin': input},
                                outputs={'hanzi': output_labels},
                                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
                        )
                        if first_serving:
                            first_serving = False
                            builder.add_meta_graph_and_variables(
                                sess=sess, tags=[tf.saved_model.tag_constants.SERVING],
                                signature_def_map={
                                    'predict_Pinyin2Hanzi':
                                        prediction_signature,
                                },
                                main_op=tf.tables_initializer(),
                                strip_default_attrs=True
                            )

                        builder.save()
                        log('Done storing serving model')
                        raise Exception('Done storing serving model')

                    # TODO: Validation
                    # if total_step % hp.steps_per_epoch == 0 and  i >= 10:
                    if total_step % hp.steps_per_epoch == 0:
                        log('validation...')
                        valid_start = time.time()
                        # TODO: validation
                        valid_hp = copy.deepcopy(hp)
                        print('feature_type: ',hp.feature_type)
                        valid_hp.data_type = 'dev'
                        valid_hp.thchs30 = True
                        valid_hp.aishell = True
                        valid_hp.prime = True
                        valid_hp.stcmd = True
                        valid_hp.shuffle = True
                        valid_hp.data_length = None

                        valid_feeder = DataFeeder(args=valid_hp)
                        valid_feeder.pny_vocab = feeder.pny_vocab
                        valid_feeder.han_vocab = feeder.han_vocab
                        # valid_feeder.am_vocab = feeder.am_vocab
                        valid_batch_data = valid_feeder.get_lm_batch()
                        log('valid_num_sentences:' + str(len(valid_feeder.wav_lst))) # 15219
                        valid_hp.input_vocab_size = len(valid_feeder.pny_vocab)
                        valid_hp.final_output_dim = len(valid_feeder.pny_vocab)
                        valid_hp.steps_per_epoch = len(valid_feeder.wav_lst) // valid_hp.batch_size
                        log('valid_steps_per_epoch:' + str(valid_hp.steps_per_epoch)) # 951
                        log('valid_pinyin_vocab_size:' + str(valid_hp.input_vocab_size)) # 1124
                        valid_hp.label_vocab_size = len(valid_feeder.han_vocab)
                        log('valid_label_vocab_size :' + str(valid_hp.label_vocab_size)) # 3327

                        # dev set: a single pass (one epoch) is enough
                        with tf.variable_scope('validation') as scope:
                            for k in range(len(valid_feeder.wav_lst) // valid_hp.batch_size):
                                valid_input_batch,valid_label_batch = next(valid_batch_data)
                                valid_feed_dict = {
                                    model.x: valid_input_batch,
                                    model.y: valid_label_batch,
                                }
                                # TODO: Run one step
                                valid_start_time = time.time()
                                valid_batch_loss,valid_batch_acc = sess.run([model.mean_loss,model.acc], feed_dict=valid_feed_dict)
                                valid_time_window.append(time.time() - valid_start_time)
                                valid_loss_window.append(valid_batch_loss)
                                valid_acc_window.append(valid_batch_acc)
                                # print('loss',loss,'batch_loss',batch_loss)
                                message = 'Valid-Step %-7d [%.03f sec/step, valid_loss=%.05f, avg_loss=%.05f, valid_acc=%.05f, avg_acc=%.05f]' % (
                                    valid_step, valid_time_window.average, valid_batch_loss, valid_loss_window.average,valid_batch_acc,valid_acc_window.average)
                                log(message)
                                summary_writer_dev.add_summary(sess.run(valid_stats,feed_dict=valid_feed_dict), valid_step)
                                valid_step += 1
                            log('Done validation! Total time cost (sec): ' + str(time.time() - valid_start))

        except Exception as e:
            log('Exiting due to exception: %s' % e)
            traceback.print_exc()
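
The ValueWindow helper that every example here leans on for running averages is not shown anywhere in this listing. A minimal sketch of what it is assumed to do (a fixed-size moving window exposing append() and an average property; the class body below is a guess, not the original implementation):

# Hypothetical ValueWindow: keeps the last `window_size` values and reports their mean.
class ValueWindow:
    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def append(self, x):
        self._values.append(x)
        # Drop the oldest entries so at most `window_size` values are retained.
        self._values = self._values[-self._window_size:]

    @property
    def average(self):
        return sum(self._values) / max(len(self._values), 1)
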
Example #2
def test():
    parser = argparse.ArgumentParser()
    # TODO: add arguments
    parser.add_argument('--log_dir',
                        default=os.path.expanduser('~/my_asr2/logdir/logging'))
    parser.add_argument(
        '--serving_dir',
        default=os.path.expanduser('~/my_asr2/logdir/serving_am/'))
    parser.add_argument('--data_dir',
                        default=os.path.expanduser('~/corpus_zn'))
    parser.add_argument('--model', default='ASR_wavnet')
    # parser.add_argument('--epochs', type=int, help='Max epochs to run.', default=100)
    parser.add_argument('--restore_step',
                        type=int,
                        help='Global step to restore from checkpoint.',
                        default=2100)
    parser.add_argument('--serving', type=bool, help='', default=False)
    # parser.add_argument('--validation_interval', type=int, help='validate 5 times per epoch, 200 steps each (3200 utterances in total)', default=7090)  # 35450//5
    parser.add_argument('--summary_interval',
                        type=int,
                        default=1,
                        help='Steps between running summary ops.')
    # parser.add_argument('--checkpoint_interval', type=int, default=100, help='Steps between writing checkpoints.')
    parser.add_argument(
        '--hparams',
        default='',
        help=
        'Hyperparameter overrides as a comma-separated list of name=value pairs'
    )
    args = parser.parse_args()

    run_name = args.model
    logdir = os.path.join(args.log_dir, 'logs-%s' % run_name)
    init(os.path.join(logdir, 'test.log'), run_name)
    hp.parse(args.hparams)

    # TODO: parse ckpt, arguments, hparams
    checkpoint_path = os.path.join(logdir, 'model.ckpt')
    input_path = args.data_dir
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading test data from: %s' % input_path)
    log('Using model : %s' % args.model)

    # TODO: set up datafeeder
    with tf.variable_scope('datafeeder') as scope:
        hp.data_type = 'test'
        hp.feature_type = 'mfcc'
        hp.data_length = None
        hp.initial_learning_rate = 0.0005
        hp.batch_size = 1
        hp.aishell = False
        hp.prime = False
        hp.stcmd = False
        hp.AM = True
        hp.LM = False
        hp.shuffle = False
        hp.is_training = False  # must be set to False for inference, otherwise batch norm will corrupt all the values!
        feeder = DataFeeder_wavnet(args=hp)
        log('num_wavs:' + str(len(feeder.wav_lst)))

        feeder.am_vocab = np.load('logdir/am_pinyin_dict.npy').tolist()
        hp.input_vocab_size = len(feeder.am_vocab)
        hp.final_output_dim = len(feeder.am_vocab)
        hp.steps_per_epoch = len(feeder.wav_lst) // hp.batch_size
        log('steps_per_epoch:' + str(hp.steps_per_epoch))
        log('pinyin_vocab_size:' + str(hp.input_vocab_size))

    # TODO: set up model
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hp)
        model.build_graph()
        model.add_loss()
        model.add_decoder()
        # model.add_optimizer(global_step=global_step)
        # TODO: summary
        stats = add_stats(model)

    # TODO: Set up saver and bookkeeping
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    wer_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=20)

    # TODO: test
    with tf.Session(graph=tf.get_default_graph()) as sess:

        log(hparams_debug_string(hp))
        try:
            # TODO: Set writer and initializer
            summary_writer = tf.summary.FileWriter(logdir + '/test',
                                                   sess.graph)
            sess.run(tf.global_variables_initializer())

            # TODO: Restore
            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s ' % restore_path)
            else:
                log('Starting new training run ')

            # TODO: epochs steps batch
            step = 0
            batch_data = feeder.get_am_batch()
            for j in range(hp.steps_per_epoch):
                input_batch = next(batch_data)
                feed_dict = {
                    model.inputs: input_batch['the_inputs'],
                    model.labels: input_batch['the_labels'],
                    model.input_lengths: input_batch['input_length'],
                    model.label_lengths: input_batch['label_length']
                }
                # TODO: Run one step
                start_time = time.time()
                array_loss, batch_loss, wer, label, final_pred_label = sess.run(
                    [
                        model.ctc_loss, model.batch_loss, model.WER,
                        model.labels, model.decoded1
                    ],
                    feed_dict=feed_dict)
                time_window.append(time.time() - start_time)
                step = step + 1

                # TODO: Append loss
                loss_window.append(batch_loss)
                wer_window.append(wer)
                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, wer=%.05f, avg_wer=%.05f]' % (
                    step, time_window.average, batch_loss, loss_window.average,
                    wer, wer_window.average)
                log(message)
                # TODO: show pred and write summary
                log('label.shape           :' +
                    str(label.shape))  # (batch_size , label_length)
                log('final_pred_label.shape:' +
                    str(np.asarray(final_pred_label).shape))
                log('label           : ' + str(label[0]))
                log('final_pred_label: ' +
                    str(np.asarray(final_pred_label)[0][0]))

                log('Writing summary at step: %d' % step)
                summary_writer.add_summary(
                    sess.run(stats, feed_dict=feed_dict), step)

                # TODO: Check loss
                if math.isnan(batch_loss):
                    log('Loss exploded to %.05f at step %d!' %
                        (batch_loss, step))
                    raise Exception('Loss Exploded')

            log('serving step: ' + str(step))
            # TODO: Set up serving builder and signature map
            serve_dir = args.serving_dir + '0001'
            if os.path.exists(serve_dir):
                shutil.rmtree(serve_dir)
                log('delete exists dirs:' + serve_dir)
            builder = tf.saved_model.builder.SavedModelBuilder(
                export_dir=serve_dir)
            input_spec = tf.saved_model.utils.build_tensor_info(model.inputs)
            input_len = tf.saved_model.utils.build_tensor_info(
                model.input_lengths)
            output_labels = tf.saved_model.utils.build_tensor_info(
                model.decoded1)
            output_logits = tf.saved_model.utils.build_tensor_info(
                model.pred_softmax)
            prediction_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={
                        'mfcc': input_spec,
                        'len': input_len
                    },
                    outputs={
                        'label': output_labels,
                        'logits': output_logits
                    },
                    method_name=tf.saved_model.signature_constants.
                    PREDICT_METHOD_NAME))

            builder.add_meta_graph_and_variables(
                sess=sess,
                tags=[tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    'predict_AudioSpec2Pinyin': prediction_signature,
                },
                main_op=tf.tables_initializer(),
                strip_default_attrs=False)
            builder.save()
            log('Done store serving-model')

        except Exception as e:
            log('Exiting due to exception: %s' % e)
            traceback.print_exc()
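
Once test() has exported the SavedModel, it can be queried from a separate TF1 session. A hedged sketch (the export path and the MFCC shape are assumptions; the signature name and tensor keys mirror the builder code above, and the 'logits' output is fetched because it is a plain dense tensor):

import os
import numpy as np
import tensorflow as tf

serve_dir = os.path.expanduser('~/my_asr2/logdir/serving_am/0001')  # assumed export path

with tf.Session(graph=tf.Graph()) as sess:
    # Load the graph and variables tagged for serving.
    meta_graph = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], serve_dir)
    signature = meta_graph.signature_def['predict_AudioSpec2Pinyin']

    mfcc_name = signature.inputs['mfcc'].name
    len_name = signature.inputs['len'].name
    logits_name = signature.outputs['logits'].name

    # Dummy batch: one utterance of 100 MFCC frames; the feature dimension (26) is an assumption.
    mfcc = np.zeros((1, 100, 26), dtype=np.float32)
    logits = sess.run(logits_name, feed_dict={mfcc_name: mfcc, len_name: [100]})
    print(logits.shape)
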
Example #3
    def train(self, dataset_train, dataset_test, dataset_train_lengths,
              dataset_test_lengths):

        # Setup data loaders

        with tf.variable_scope('train_iterator'):
            self.iterator_data_train = dataset_train.make_initializable_iterator()
            self.iterator_length_train = dataset_train_lengths.make_initializable_iterator()
            next_train_data = self.iterator_data_train.get_next()
            next_train_length = self.iterator_length_train.get_next()

        with tf.variable_scope('test_iterator'):
            self.iterator_data_test = dataset_test.make_initializable_iterator()
            self.iterator_length_test = dataset_test_lengths.make_initializable_iterator()
            next_test_data = self.iterator_data_test.get_next()
            next_test_length = self.iterator_length_test.get_next()

        # Set up model
        self.initializers = [
            i.initializer for i in [
                self.iterator_data_test, self.iterator_data_train,
                self.iterator_length_test, self.iterator_length_train
            ]
        ]
        self.model, self.stats = self.model_train_mode(next_train_data,
                                                       next_train_length,
                                                       self.global_step)
        self.eval_model = self.model_eval_mode(next_test_data,
                                               next_test_length)

        if self.all_params.use_ema:
            self.saver = create_shadow_saver(self.model, self.global_step)

        else:
            self.saver = tf.train.Saver(max_to_keep=100)

        # Book keeping
        step = 0
        time_window = ValueWindow(100)
        loss_window = ValueWindow(100)

        print('Training set to a maximum of {} steps'.format(self.all_params.train_steps))

        # Memory allocation on the GPU as needed
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        # Train
        print('Starting training')
        with tf.Session(config=config) as sess:
            summary_writer = tf.summary.FileWriter(self.tensorboard_dir,
                                                   sess.graph)
            # Allow the full trace to be stored at run time.
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            # Create a fresh metadata object:
            run_metadata = tf.RunMetadata()

            sess.run(tf.global_variables_initializer())
            for init in self.initializers:
                sess.run(init)

            # saved model restoring
            if self.all_params.restore:
                # Restore saved model if the user requested it, default = True
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(
                        self.save_dir)

                    if (checkpoint_state
                            and checkpoint_state.model_checkpoint_path):
                        print('Loading checkpoint {}'.format(
                            checkpoint_state.model_checkpoint_path))
                        self.saver.restore(
                            sess, checkpoint_state.model_checkpoint_path)

                    else:
                        print('No model to load at {}'.format(self.save_dir))
                        self.saver.save(sess,
                                        self.checkpoint_path,
                                        global_step=self.global_step)

                except tf.errors.OutOfRangeError as e:
                    print('Cannot restore checkpoint: {}'.format(e))
            else:
                print('Starting new training!')
                self.saver.save(sess,
                                self.checkpoint_path,
                                global_step=self.global_step)

            # Training loop
            while not self.coord.should_stop() and step < self.all_params.train_steps:
                start_time = time.time()
                step, loss, opt = sess.run(
                    [self.global_step, self.model.loss, self.model.optimize],
                    options=options,
                    run_metadata=run_metadata)
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
                    step, time_window.average, loss, loss_window.average)
                print(message)

                if np.isnan(loss) or loss > 100.:
                    print('Loss exploded to {:.5f} at step {}'.format(
                        loss, step))
                    raise Exception('Loss exploded')

                if step % self.all_params.summary_interval == 0:
                    print('\nWriting summary at step {}'.format(step))
                    summary_writer.add_summary(sess.run(self.stats), step)

                if step % self.all_params.checkpoint_interval == 0 or step == self.all_params.train_steps:
                    print('Saving model!')
                    # Save model and current global step
                    self.saver.save(sess,
                                    self.checkpoint_path,
                                    global_step=self.global_step)

                if step % self.all_params.eval_interval == 0:
                    # Run eval and save eval stats
                    print('\nRunning evaluation at step {}'.format(step))

                    all_logits = []
                    all_outputs = []
                    all_targets = []
                    all_lengths = []
                    val_losses = []

                    for i in tqdm(range(4)):
                        val_loss, logits, outputs, targets, lengths = sess.run(
                            [
                                self.eval_model.loss, self.eval_model.logits,
                                self.eval_model.outputs,
                                self.eval_model.targets,
                                self.eval_model.input_lengths
                            ])

                        all_logits.append(logits)
                        all_outputs.append(outputs)
                        all_targets.append(targets)
                        all_lengths.append(lengths)
                        val_losses.append(val_loss)

                    logits = [l for logits in all_logits for l in logits]
                    outputs = [o for output in all_outputs for o in output]
                    targets = [t for target in all_targets for t in target]
                    lengths = [l for length in all_lengths for l in length]

                    logits = np.array([
                        e for o, l in zip(logits, lengths) for e in o[:l]
                    ]).reshape(-1)
                    outputs = np.array([
                        e for o, l in zip(outputs, lengths) for e in o[:l]
                    ]).reshape(-1)
                    targets = np.array([
                        e for t, l in zip(targets, lengths) for e in t[:l]
                    ]).reshape(-1)

                    val_loss = sum(val_losses) / len(val_losses)

                    assert len(targets) == len(outputs)
                    capture_rate, fig_path = evaluate_and_plot(
                        outputs,
                        targets,
                        index=np.arange(0, len(targets)),
                        model_name=self.all_params.name or self.all_params.model,
                        weight=self.all_params.capture_weight,
                        out_dir=self.eval_dir,
                        use_tf=False,
                        sess=sess,
                        step=step)

                    add_eval_stats(summary_writer, step, val_loss,
                                   capture_rate)

                    tensorboard_file = os.path.join(
                        self.tensorboard_dir,
                        os.listdir(self.tensorboard_dir)[0])

                    ###### Replace these lines ###########################
                    print(f'train_loss: {float(loss)}')
                    print(f'validation_loss: {float(val_loss)}')
                    print(f'validation_capture_rate: {float(capture_rate)}')

                    ######################################################

            print('Training complete after {} global steps!'.format(
                self.all_params.train_steps))
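
The add_eval_stats helper called above is not shown; a plausible sketch, assuming it simply writes the validation loss and capture rate as scalar summaries via a hand-built tf.Summary proto (the tag names are assumptions):

import tensorflow as tf

def add_eval_stats(summary_writer, step, val_loss, capture_rate):
    # Write eval scalars directly, without needing summary ops in the graph.
    summary = tf.Summary(value=[
        tf.Summary.Value(tag='eval/loss', simple_value=float(val_loss)),
        tf.Summary.Value(tag='eval/capture_rate', simple_value=float(capture_rate)),
    ])
    summary_writer.add_summary(summary, step)
    summary_writer.flush()
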
Example #4
def train(train_loader, model, device, mels_criterion, stop_criterion,
          optimizer, scheduler, writer, train_dir):
    batch_time = ValueWindow()
    data_time = ValueWindow()
    losses = ValueWindow()

    # switch to train mode
    model.train()

    end = time.time()
    global global_epoch
    global global_step
    for i, (txts, mels, stop_tokens, txt_lengths,
            mels_lengths) in enumerate(train_loader):
        scheduler.adjust_learning_rate(optimizer, global_step)
        # measure data loading time
        data_time.update(time.time() - end)

        if device > -1:
            txts = txts.cuda(device)
            mels = mels.cuda(device)
            stop_tokens = stop_tokens.cuda(device)
            txt_lengths = txt_lengths.cuda(device)
            mels_lengths = mels_lengths.cuda(device)

        # compute output
        frames, decoder_frames, stop_tokens_predict, alignment = model(
            txts, txt_lengths, mels)
        decoder_frames_loss = mels_criterion(decoder_frames,
                                             mels,
                                             lengths=mels_lengths)
        frames_loss = mels_criterion(frames, mels, lengths=mels_lengths)
        stop_token_loss = stop_criterion(stop_tokens_predict,
                                         stop_tokens,
                                         lengths=mels_lengths)
        loss = decoder_frames_loss + frames_loss + stop_token_loss

        #print(frames_loss, decoder_frames_loss)
        losses.update(loss.item())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        if hparams.clip_thresh > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.get_trainable_parameters(), hparams.clip_thresh)

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % hparams.print_freq == 0:
            log('Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                    global_epoch,
                    i,
                    len(train_loader),
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses))

        # Logs
        writer.add_scalar("loss", float(loss.item()), global_step)
        writer.add_scalar(
            "avg_loss in {} window".format(losses.get_dinwow_size),
            float(losses.avg), global_step)
        writer.add_scalar("stop_token_loss", float(stop_token_loss.item()),
                          global_step)
        writer.add_scalar("decoder_frames_loss",
                          float(decoder_frames_loss.item()), global_step)
        writer.add_scalar("output_frames_loss", float(frames_loss.item()),
                          global_step)

        if hparams.clip_thresh > 0:
            writer.add_scalar("gradient norm", grad_norm, global_step)
        writer.add_scalar("learning rate", optimizer.param_groups[0]['lr'],
                          global_step)
        global_step += 1

    dst_alignment_path = join(train_dir,
                              "{}_alignment.png".format(global_step))
    alignment = alignment.cpu().detach().numpy()
    plot_alignment(alignment[0, :txt_lengths[0], :mels_lengths[0]],
                   dst_alignment_path,
                   info="{}, {}".format(hparams.builder, global_step))
Example #5
def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
          rank, group_name, hparams, refine_from):
    """Training and validation logging results to tensorboard and stdout

    Params
    ------
    output_directory (string): directory to save checkpoints
    log_directory (string): directory to save tensorboard logs
    checkpoint_path (string): checkpoint path
    n_gpus (int): number of gpus
    rank (int): rank of current gpu
    hparams (object): comma separated list of "name=value" pairs.
    """
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    model = load_model(hparams)
    learning_rate = hparams.initial_learning_rate
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 weight_decay=hparams.weight_decay)

    if hparams.use_GAN and hparams.GAN_type=='lsgan':
        from discriminator import Lsgan_Loss, Calculate_Discrim
        model_D = Calculate_Discrim(hparams).cuda() if torch.cuda.is_available() else Calculate_Discrim(hparams)
        lsgan_loss = Lsgan_Loss(hparams)
        optimizer_D = torch.optim.Adam(model_D.parameters(), lr=learning_rate,
                                 weight_decay=hparams.weight_decay)
    
    if hparams.use_GAN and hparams.GAN_type=='wgan-gp':
        from discriminator import Wgan_GP, GP
        model_D = Wgan_GP(hparams).cuda() if torch.cuda.is_available() else Wgan_GP(hparams)
        calc_gradient_penalty = GP(hparams).cuda() if torch.cuda.is_available() else GP(hparams)
        optimizer_D = torch.optim.Adam(model_D.parameters(), lr=learning_rate,
                                 weight_decay=hparams.weight_decay)   

    if hparams.is_partial_refine:
        refine_list=['speaker_embedding.weight',
        'spkemb_projection.weight',
        'spkemb_projection.bias',
        'projection.weight',
        'projection.bias',
        'encoder.encoders.0.norm1.w.weight',
        'encoder.encoders.0.norm1.w.bias',
        'encoder.encoders.0.norm1.b.weight',
        'encoder.encoders.0.norm1.b.bias',
        'encoder.encoders.0.norm2.w.weight',
        'encoder.encoders.0.norm2.w.bias',
        'encoder.encoders.0.norm2.b.weight',
        'encoder.encoders.0.norm2.b.bias',
        'encoder.encoders.1.norm1.w.weight',
        'encoder.encoders.1.norm1.w.bias',
        'encoder.encoders.1.norm1.b.weight',
        'encoder.encoders.1.norm1.b.bias',
        'encoder.encoders.1.norm2.w.weight',
        'encoder.encoders.1.norm2.w.bias',
        'encoder.encoders.1.norm2.b.weight',
        'encoder.encoders.1.norm2.b.bias',
        'encoder.encoders.2.norm1.w.weight',
        'encoder.encoders.2.norm1.w.bias',
        'encoder.encoders.2.norm1.b.weight',
        'encoder.encoders.2.norm1.b.bias',
        'encoder.encoders.2.norm2.w.weight',
        'encoder.encoders.2.norm2.w.bias',
        'encoder.encoders.2.norm2.b.weight',
        'encoder.encoders.2.norm2.b.bias',
        'encoder.encoders.3.norm1.w.weight',
        'encoder.encoders.3.norm1.w.bias',
        'encoder.encoders.3.norm1.b.weight',
        'encoder.encoders.3.norm1.b.bias',
        'encoder.encoders.3.norm2.w.weight',
        'encoder.encoders.3.norm2.w.bias',
        'encoder.encoders.3.norm2.b.weight',
        'encoder.encoders.3.norm2.b.bias',
        'encoder.encoders.4.norm1.w.weight',
        'encoder.encoders.4.norm1.w.bias',
        'encoder.encoders.4.norm1.b.weight',
        'encoder.encoders.4.norm1.b.bias',
        'encoder.encoders.4.norm2.w.weight',
        'encoder.encoders.4.norm2.w.bias',
        'encoder.encoders.4.norm2.b.weight',
        'encoder.encoders.4.norm2.b.bias',
        'encoder.encoders.5.norm1.w.weight',
        'encoder.encoders.5.norm1.w.bias',
        'encoder.encoders.5.norm1.b.weight',
        'encoder.encoders.5.norm1.b.bias',
        'encoder.encoders.5.norm2.w.weight',
        'encoder.encoders.5.norm2.w.bias',
        'encoder.encoders.5.norm2.b.weight',
        'encoder.encoders.5.norm2.b.bias',
        'encoder.after_norm.w.weight',
        'encoder.after_norm.w.bias',
        'encoder.after_norm.b.weight',
        'encoder.after_norm.b.bias',
        'duration_predictor.norm.0.w.weight',
        'duration_predictor.norm.0.w.bias',
        'duration_predictor.norm.0.b.weight',
        'duration_predictor.norm.0.b.bias',
        'duration_predictor.norm.1.w.weight',
        'duration_predictor.norm.1.w.bias',
        'duration_predictor.norm.1.b.weight',
        'duration_predictor.norm.1.b.bias',
        'decoder.encoders.0.norm1.w.weight',
        'decoder.encoders.0.norm1.w.bias',
        'decoder.encoders.0.norm1.b.weight',
        'decoder.encoders.0.norm1.b.bias',
        'decoder.encoders.0.norm2.w.weight',
        'decoder.encoders.0.norm2.w.bias',
        'decoder.encoders.0.norm2.b.weight',
        'decoder.encoders.0.norm2.b.bias',
        'decoder.encoders.1.norm1.w.weight',
        'decoder.encoders.1.norm1.w.bias',
        'decoder.encoders.1.norm1.b.weight',
        'decoder.encoders.1.norm1.b.bias',
        'decoder.encoders.1.norm2.w.weight',
        'decoder.encoders.1.norm2.w.bias',
        'decoder.encoders.1.norm2.b.weight',
        'decoder.encoders.1.norm2.b.bias',
        'decoder.encoders.2.norm1.w.weight',
        'decoder.encoders.2.norm1.w.bias',
        'decoder.encoders.2.norm1.b.weight',
        'decoder.encoders.2.norm1.b.bias',
        'decoder.encoders.2.norm2.w.weight',
        'decoder.encoders.2.norm2.w.bias',
        'decoder.encoders.2.norm2.b.weight',
        'decoder.encoders.2.norm2.b.bias',
        'decoder.encoders.3.norm1.w.weight',
        'decoder.encoders.3.norm1.w.bias',
        'decoder.encoders.3.norm1.b.weight',
        'decoder.encoders.3.norm1.b.bias',
        'decoder.encoders.3.norm2.w.weight',
        'decoder.encoders.3.norm2.w.bias',
        'decoder.encoders.3.norm2.b.weight',
        'decoder.encoders.3.norm2.b.bias',
        'decoder.encoders.4.norm1.w.weight',
        'decoder.encoders.4.norm1.w.bias',
        'decoder.encoders.4.norm1.b.weight',
        'decoder.encoders.4.norm1.b.bias',
        'decoder.encoders.4.norm2.w.weight',
        'decoder.encoders.4.norm2.w.bias',
        'decoder.encoders.4.norm2.b.weight',
        'decoder.encoders.4.norm2.b.bias',
        'decoder.encoders.5.norm1.w.weight',
        'decoder.encoders.5.norm1.w.bias',
        'decoder.encoders.5.norm1.b.weight',
        'decoder.encoders.5.norm1.b.bias',
        'decoder.encoders.5.norm2.w.weight',
        'decoder.encoders.5.norm2.w.bias',
        'decoder.encoders.5.norm2.b.weight',
        'decoder.encoders.5.norm2.b.bias',
        'decoder.after_norm.w.weight',
        'decoder.after_norm.w.bias',
        'decoder.after_norm.b.weight',
        'decoder.after_norm.b.bias']
        if hparams.is_refine_style:
            style_list= ['gst.ref_enc.convs.0.weight',
        'gst.ref_enc.convs.1.weight',
        'gst.ref_enc.convs.1.bias',
        'gst.ref_enc.convs.3.weight',
        'gst.ref_enc.convs.4.weight',
        'gst.ref_enc.convs.4.bias',
        'gst.ref_enc.convs.6.weight',
        'gst.ref_enc.convs.7.weight',
        'gst.ref_enc.convs.7.bias',
        'gst.ref_enc.convs.9.weight',
        'gst.ref_enc.convs.10.weight',
        'gst.ref_enc.convs.10.bias',
        'gst.ref_enc.convs.12.weight',
        'gst.ref_enc.convs.13.weight',
        'gst.ref_enc.convs.13.bias',
        'gst.ref_enc.convs.15.weight',
        'gst.ref_enc.convs.16.weight',
        'gst.ref_enc.convs.16.bias',
        'gst.ref_enc.gru.weight_ih_l0,'
        'gst.ref_enc.gru.weight_hh_l0',
        'gst.ref_enc.gru.bias_ih_l0',
        'gst.ref_enc.gru.bias_hh_l0',
        'gst.stl.gst_embs',
        'gst.stl.mha.linear_q.weight',
        'gst.stl.mha.linear_q.bias',
        'gst.stl.mha.linear_k.weight',
        'gst.stl.mha.linear_k.bias',
        'gst.stl.mha.linear_v.weight',
        'gst.stl.mha.linear_v.bias',
        'gst.stl.mha.linear_out.weight',
        'gst.stl.mha.linear_out.bias',
        'gst.choosestl.choose_mha.linear_q.weight',
        'gst.choosestl.choose_mha.linear_q.bias',
        'gst.choosestl.choose_mha.linear_k.weight',
        'gst.choosestl.choose_mha.linear_k.bias',
        'gst.choosestl.choose_mha.linear_v.weight',
        'gst.choosestl.choose_mha.linear_v.bias',
        'gst.choosestl.choose_mha.linear_out.weight',
        'gst.choosestl.choose_mha.linear_out.bias',
        'gst_projection.weight',
        'gst_projection.bias'
        ]
        refine_list += style_list
    
    for name, param in model.named_parameters():
        if hparams.is_partial_refine:
            if name in refine_list:
                param.requires_grad = True 
            else:
                param.requires_grad = False
        print(name, param.requires_grad, param.shape)

    if hparams.distributed_run:
        model = apply_gradient_allreduce(model)
        if hparams.use_GAN:
            model_D = apply_gradient_allreduce(model_D)

    logger = prepare_directories_and_logger(output_directory, log_directory, rank)

    train_loader, valset, collate_fn, trainset = prepare_dataloaders(hparams)

    # Load checkpoint if one exists
    iteration = 0
    epoch_offset = 0
    if not checkpoint_path:
        checkpoint_path = get_checkpoint_path(output_directory) if not hparams.is_partial_refine else get_checkpoint_path(refine_from)
    if checkpoint_path is not None:
        if warm_start:
            model = warm_start_model(
                checkpoint_path, model, hparams.ignore_layers)
        else:
            model, optimizer, _learning_rate, iteration = load_checkpoint(
                checkpoint_path, model, optimizer, hparams, style_list=style_list if hparams.is_refine_style else None)
            if hparams.use_saved_learning_rate:
                learning_rate = _learning_rate
            iteration = (iteration + 1) if not hparams.is_partial_refine else 0  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader))) if not hparams.is_partial_refine else 0

    model.train()
    if hparams.use_GAN:
        model_D.train()
    else:
        hparams.use_GAN = True
        hparams.Generator_pretrain_step = hparams.iters
    is_overflow = False
    epoch = epoch_offset
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    # ================ MAIN TRAINING LOOP! ===================
    while iteration <= hparams.iters:
        # print("Epoch: {}".format(epoch))
        if hparams.distributed_run and hparams.batch_criterion == 'utterance':
            train_loader.sampler.set_epoch(epoch)
        for i, batch in enumerate(train_loader):
            start = time.perf_counter()
            learning_rate = learning_rate_decay(iteration, hparams)
            if hparams.use_GAN:

                # Discriminator turn
                if iteration > hparams.Generator_pretrain_step:
                    for param_group in optimizer_D.param_groups:
                        param_group['lr'] = learning_rate
                    optimizer.zero_grad()
                    optimizer_D.zero_grad()
                    for name, param in model.named_parameters():
                        param.requires_grad = False
                    for name, param in model_D.named_parameters():
                        param.requires_grad = True 

                    loss, loss_dict, weight, pred_outs, ys, olens = model(*model._parse_batch(batch,hparams,utt_mels=trainset.utt_mels if hparams.is_refine_style else None))
                    if hparams.GAN_type=='lsgan':
                        discrim_gen_output, discrim_target_output = model_D(pred_outs + (torch.randn(pred_outs.size()).cuda() if hparams.add_noise else 0), ys + (torch.randn(pred_outs.size()).cuda() if hparams.add_noise else 0), olens)
                        loss_D = lsgan_loss(discrim_gen_output, discrim_target_output, train_object='D')
                        loss_G = lsgan_loss(discrim_gen_output, discrim_target_output, train_object='G')                
                        loss_D.backward(retain_graph=True)
                    if hparams.GAN_type=='wgan-gp':
                        D_real = model_D(ys, olens)
                        D_real = -D_real.mean()
                        D_real.backward(retain_graph=True)
                        D_fake = model_D(pred_outs, olens)
                        D_fake = D_fake.mean()
                        D_fake.backward()
                        gradient_penalty = calc_gradient_penalty(model_D, ys.data, pred_outs.data, olens.data)
                        gradient_penalty.backward()
                        D_cost = D_real + D_fake + gradient_penalty
                        Wasserstein_D = -D_real - D_fake
                    grad_norm_D = torch.nn.utils.clip_grad_norm_(model_D.parameters(), hparams.grad_clip_thresh)
                    optimizer_D.step()
                    print('\n')
                    if hparams.GAN_type=='lsgan':
                        print("Epoch:{} step:{} loss_D: {:>9.6f}, loss_G: {:>9.6f}, Grad Norm: {:>9.6f}".format(epoch, iteration, loss_D, loss_G, grad_norm_D))
                    if hparams.GAN_type=='wgan-gp':
                        print("Epoch:{} step:{} D_cost: {:>9.6f}, Wasserstein_D: {:>9.6f}, GP: {:>9.6f}, Grad Norm: {:>9.6f}".format(epoch, iteration, D_cost, Wasserstein_D, gradient_penalty, grad_norm_D))

                # Generator turn
                for param_group in optimizer.param_groups:
                    param_group['lr'] = learning_rate 
                optimizer.zero_grad()
                if iteration > hparams.Generator_pretrain_step:
                    for name, param in model.named_parameters():
                        if hparams.is_partial_refine:
                            if name in refine_list:
                                param.requires_grad = True
                        else:
                            param.requires_grad = True                                 
                    for name, param in model_D.named_parameters():
                        param.requires_grad = False
                    optimizer_D.zero_grad()
                loss, loss_dict, weight, pred_outs, ys, olens = model(*model._parse_batch(batch,hparams,utt_mels=trainset.utt_mels if hparams.is_refine_style else None))
                if iteration > hparams.Generator_pretrain_step:
                    if hparams.GAN_type=='lsgan':
                        discrim_gen_output, discrim_target_output = model_D(pred_outs, ys, olens)
                        loss_D = lsgan_loss(discrim_gen_output, discrim_target_output, train_object='D')
                        loss_G = lsgan_loss(discrim_gen_output, discrim_target_output, train_object='G')
                    if hparams.GAN_type=='wgan-gp':
                        loss_G = model_D(pred_outs, olens)
                        loss_G = -loss_G.mean()
                    loss = loss + loss_G*hparams.GAN_alpha*abs(loss.item()/loss_G.item())
                if hparams.distributed_run:
                    reduced_loss = reduce_tensor(loss.data, n_gpus).item()
                    if loss_dict:
                        for key in loss_dict:
                            loss_dict[key] = reduce_tensor(loss_dict[key].data, n_gpus).item()
                else:
                    reduced_loss = loss.item()
                    if loss_dict:
                        for key in loss_dict:
                            loss_dict[key] = loss_dict[key].item()
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
                optimizer.step()
                duration = time.perf_counter() - start
                time_window.append(duration)
                loss_window.append(reduced_loss)
                if not is_overflow and (rank == 0):
                    if iteration % hparams.log_per_checkpoint == 0:
                        if hparams.GAN_type=='lsgan':
                            print("Epoch:{} step:{} Train loss: {:>9.6f}, avg loss: {:>9.6f}, Grad Norm: {:>9.6f}, {:>5.2f}s/it, {:s} loss: {:>9.6f}, D_loss: {:>9.6f}, G_loss: {:>9.6f}, duration loss: {:>9.6f}, ssim loss: {:>9.6f}, lr: {:>4}".format(
                            epoch, iteration, reduced_loss, loss_window.average, grad_norm, time_window.average, hparams.loss_type, loss_dict[hparams.loss_type], loss_D.item() if iteration > hparams.Generator_pretrain_step else 0, loss_G.item() if iteration > hparams.Generator_pretrain_step else 0, loss_dict["duration_loss"], loss_dict["ssim_loss"], learning_rate))
                        if hparams.GAN_type=='wgan-gp':
                            print("Epoch:{} step:{} Train loss: {:>9.6f}, avg loss: {:>9.6f}, Grad Norm: {:>9.6f}, {:>5.2f}s/it, {:s} loss: {:>9.6f}, G_loss: {:>9.6f}, duration loss: {:>9.6f}, ssim loss: {:>9.6f}, lr: {:>4}".format(
                            epoch, iteration, reduced_loss, loss_window.average, grad_norm, time_window.average, hparams.loss_type, loss_dict[hparams.loss_type], loss_G.item() if iteration > hparams.Generator_pretrain_step else 0, loss_dict["duration_loss"], loss_dict["ssim_loss"], learning_rate))
                        logger.log_training(
                            reduced_loss, grad_norm, learning_rate, duration, iteration, loss_dict)

            if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0):
                if valset is not None:
                    validate(model, valset, iteration, hparams.batch_criterion,
                            hparams.batch_size, n_gpus, collate_fn, logger,
                            hparams.distributed_run, rank)
                if rank == 0:
                    checkpoint_path = os.path.join(
                        output_directory, "checkpoint_{}_refine_{}".format(iteration, hparams.training_files.split('/')[-2].split('_')[-1]) if hparams.is_partial_refine else "checkpoint_{}".format(iteration))
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1
            torch.cuda.empty_cache()
        epoch += 1
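
The learning_rate_decay(iteration, hparams) helper is not part of this listing; a hedged sketch, assuming linear warmup followed by exponential decay and hypothetical hparams fields (warmup_steps, decay_rate, decay_steps) with defaults when they are absent:

def learning_rate_decay(iteration, hparams):
    # Hypothetical schedule: linear warmup, then exponential decay.
    base_lr = hparams.initial_learning_rate
    warmup_steps = getattr(hparams, 'warmup_steps', 4000)
    if iteration < warmup_steps:
        return base_lr * (iteration + 1) / warmup_steps
    decay_rate = getattr(hparams, 'decay_rate', 0.5)
    decay_steps = getattr(hparams, 'decay_steps', 50000)
    return base_lr * decay_rate ** ((iteration - warmup_steps) / decay_steps)
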
Example #6
            sess.run(dev_iterator.initializer, {num_frames: max_frames})

            try:
                checkpoint_state = tf.train.get_checkpoint_state(train_hparams.checkpoint_path)

                if (checkpoint_state and checkpoint_state.model_checkpoint_path):
                    log('Loading checkpoint {}'.format(checkpoint_state.model_checkpoint_path), slack=True)
                    saver.restore(sess, checkpoint_state.model_checkpoint_path)
                else:
                    log('No model to load at {}'.format(train_hparams.checkpoint_path), slack=True)
                    saver.save(sess, checkpoint_path, global_step=train_resnet.global_step)
            except OutOfRangeError as e:
                log('Cannot restore checkpoint: {}'.format(e), slack=True)

            step = 0
            time_window = ValueWindow(100)
            loss_window = ValueWindow(100)
            acc_window = ValueWindow(100)
            while step < train_hparams.train_steps:
                start_time = time.time()

                fetches = [train_resnet.global_step, train_resnet.train_op, train_resnet.cost, train_resnet.accuracy]
                feed_dict = {handle: train_handle}

                try:
                    step, _, loss, acc = sess.run(fetches=fetches, feed_dict=feed_dict)
                except OutOfRangeError as e:
                    sess.run(train_iterator.initializer, {num_frames: max_frames})
                    continue

                time_window.append(time.time() - start_time)
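
Example #6 is a fragment that relies on the TF1 "feedable iterator" pattern (the handle placeholder and the train_handle string it feeds). A hedged sketch of how that plumbing is typically wired up, with dummy datasets standing in for the real audio pipelines:

import tensorflow as tf

# Dummy datasets standing in for the real train/dev pipelines (shapes are placeholders).
train_dataset = tf.data.Dataset.from_tensor_slices(tf.zeros([8, 10])).batch(2)
dev_dataset = tf.data.Dataset.from_tensor_slices(tf.zeros([4, 10])).batch(2)

train_iterator = train_dataset.make_initializable_iterator()
dev_iterator = dev_dataset.make_initializable_iterator()

# A string handle placeholder selects which pipeline get_next() reads from.
handle = tf.placeholder(tf.string, shape=[], name='iterator_handle')
iterator = tf.data.Iterator.from_string_handle(
    handle, train_dataset.output_types, train_dataset.output_shapes)
next_batch = iterator.get_next()

with tf.Session() as sess:
    train_handle, dev_handle = sess.run(
        [train_iterator.string_handle(), dev_iterator.string_handle()])
    sess.run([train_iterator.initializer, dev_iterator.initializer])
    # Feeding the train handle makes next_batch pull from the training pipeline,
    # just as the fragment above does with feed_dict = {handle: train_handle}.
    print(sess.run(next_batch, feed_dict={handle: train_handle}).shape)
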
Example #7
def train_epoch(model, train_loader, loss_fn, optimizer, scheduler, batch_size,
                epoch, start_step):
    model.train()
    count = 0
    total_loss = 0
    n = batch_size
    step = start_step
    examples = []
    total_loss_window = ValueWindow(100)
    post_loss_window = ValueWindow(100)
    post_acc_window = ValueWindow(100)

    for x, y in train_loader:
        count += 1
        examples.append([x[0], y[0]])

        if count % 8 == 0:
            examples.sort(key=lambda x: len(x[-1]))
            examples = (np.vstack([ex[0] for ex in examples]),
                        np.vstack([ex[1] for ex in examples]))
            batches = [(examples[0][i:i + n], examples[1][i:i + n])
                       for i in range(0,
                                      len(examples[-1]) + 1 - n, n)]

            if len(examples[-1]) % n != 0:
                batches.append(
                    (np.vstack((examples[0][-(len(examples[-1]) % n):],
                                examples[0][:n - (len(examples[0]) % n)])),
                     np.vstack((examples[1][-(len(examples[-1]) % n):],
                                examples[1][:n - (len(examples[-1]) % n)]))))

            for batch in batches:  # mini batch
                # train_data(?, 7, 80), train_label(?, 7)
                step += 1
                train_data = torch.as_tensor(batch[0],
                                             dtype=torch.float32).to(DEVICE)
                train_label = torch.as_tensor(batch[1],
                                              dtype=torch.float32).to(DEVICE)

                optimizer.zero_grad(True)
                midnet_output, postnet_output, alpha = model(train_data)
                postnet_accuracy, pipenet_accuracy = prediction(
                    train_label, midnet_output, postnet_output)
                loss, postnet_loss, pipenet_loss, attention_loss = loss_fn(
                    model, train_label, postnet_output, midnet_output, alpha)
                total_loss += loss.detach().item()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 5, norm_type=2)
                optimizer.step()
                scheduler.step()
                lr = scheduler._rate

                total_loss_window.append(loss.detach().item())
                post_loss_window.append(postnet_loss.detach().item())
                post_acc_window.append(postnet_accuracy)
                if step % 10 == 0:
                    print(
                        '{}  Epoch: {}, Step: {}, overall loss: {:.5f}, postnet loss: {:.5f}, '
                        'postnet acc: {:.4f}, lr :{:.5f}'.format(
                            datetime.now().strftime(_format)[:-3], epoch, step,
                            total_loss_window.average,
                            post_loss_window.average, post_acc_window.average,
                            lr))
                if step % 50_000 == 0:
                    print('{} save checkpoint.'.format(
                        datetime.now().strftime(_format)[:-3]))
                    checkpoint = {
                        "model": model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        "epoch": epoch,
                        'step': step,
                        'scheduler_lr': scheduler._rate,
                        'scheduler_step': scheduler._step
                    }
                    if not os.path.isdir("./checkpoint"):
                        os.mkdir("./checkpoint")
                    torch.save(
                        checkpoint, './checkpoint/STAM_weights_%s_%s.pth' %
                        (str(epoch), str(step / 1_000_000)))
                    gc.collect()
                    torch.cuda.empty_cache()
            del batches, examples
            examples = []
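
A hedged sketch of restoring the checkpoint dictionary written by train_epoch (the function name below is hypothetical; model, optimizer, and scheduler are assumed to be constructed the same way as during training before it is called, and the scheduler fields mirror the keys saved above):

import torch

def load_checkpoint(path, model, optimizer, scheduler, device='cpu'):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # These private fields match what train_epoch stored for its custom scheduler.
    scheduler._rate = checkpoint['scheduler_lr']
    scheduler._step = checkpoint['scheduler_step']
    return checkpoint['epoch'], checkpoint['step']
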
Example #8
def train(logdir,args):

    # TODO: parse ckpt, arguments, hparams
    checkpoint_path = os.path.join(logdir,'model.ckpt')
    input_path = args.data_dir
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from : %s ' % input_path)
    log('Using model : %s' %args.model)

    # TODO: set up datafeeder
    with tf.variable_scope('datafeeder') as scope:
        hp.aishell=True
        hp.prime=False
        hp.stcmd=False
        hp.data_path = 'D:/pycharm_proj/corpus_zn/'
        hp.initial_learning_rate = 0.001
        hp.decay_learning_rate=False
        hp.data_length = 512
        hp.batch_size = 64
        feeder = DataFeeder(args=hp)
        log('num_wavs:'+str(len(feeder.wav_lst))) # 283600
        hp.input_vocab_size = len(feeder.pny_vocab)
        hp.final_output_dim = len(feeder.pny_vocab)
        hp.steps_per_epoch = len(feeder.wav_lst)//hp.batch_size
        log('steps_per_epoch:' + str(hp.steps_per_epoch))  # 17725
        log('pinyin_vocab_size:'+str(hp.input_vocab_size)) # 1292
        hp.label_vocab_size = len(feeder.han_vocab)
        log('label_vocab_size :' + str(hp.label_vocab_size)) # 6291

    # TODO: set up model
    global_step = tf.Variable(initial_value=0,name='global_step',trainable=False)
    valid_step = 0
    # valid_global_step = tf.Variable(initial_value=0,name='valid_global_step',trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model,hp)
        model.build_graph()
        model.add_loss()
        model.add_decoder()
        model.add_optimizer(global_step=global_step)
        # TODO: summary
        stats = add_stats(model=model)
        valid_stats = add_dev_stats(model)


    # TODO: Set up saver and bookkeeping
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    wer_window = ValueWindow(100)
    valid_time_window = ValueWindow(100)
    valid_loss_window = ValueWindow(100)
    valid_wer_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=20)
    first_serving = True
    # TODO: train
    with tf.Session() as sess:
        try:
            # TODO: Set writer and initializer
            summary_writer = tf.summary.FileWriter(logdir, sess.graph)
            sess.run(tf.global_variables_initializer())

            # TODO: Restore
            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s ' % restore_path)
            else:
                log('Starting new training run ')

            step = 0
            # TODO: epochs steps batch
            for i in range(args.epochs):
                batch_data = feeder.get_am_batch()
                log('Training epoch ' + str(i) + ':')
                for j in range(hp.steps_per_epoch):
                    input_batch = next(batch_data)
                    feed_dict = {model.inputs:input_batch['the_inputs'],
                                 model.labels:input_batch['the_labels'],
                                 model.input_lengths:input_batch['input_length'],
                                 model.label_lengths:input_batch['label_length']}
                    # TODO: Run one step
                    start_time = time.time()
                    total_step, array_loss, batch_loss,opt = sess.run([global_step, model.ctc_loss,
                                                            model.batch_loss,model.optimize],feed_dict=feed_dict)
                    time_window.append(time.time() - start_time)
                    step = total_step

                    # TODO: Append loss
                    # loss = np.sum(array_loss).item()/hp.batch_size
                    loss_window.append(batch_loss)

                    # print('loss',loss,'batch_loss',batch_loss)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, lr=%.07f]' % (
                        step, time_window.average, batch_loss, loss_window.average,K.get_value(model.learning_rate))
                    log(message)
                    # ctc_loss returns an array shaped [batch_size, 1]; without summing it first, this raised "only size-1 arrays can be converted to Python scalars"

                    # TODO: Check loss
                    if math.isnan(batch_loss):
                        log('Loss exploded to %.05f at step %d!' % (batch_loss, step))
                        raise Exception('Loss Exploded')

                    # TODO: Check summary
                    if step % args.summary_interval == 0:
                        log('Writing summary at step: %d' % step)
                        summary_writer.add_summary(sess.run(stats,feed_dict=feed_dict), step)

                    # TODO: Check checkpoint
                    if step % args.checkpoint_interval == 0:
                        log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                        saver.save(sess, checkpoint_path, global_step=step)
                        log('test acc...')
                        # eval_start_time = time.time()
                        # with tf.name_scope('eval') as scope:
                        #     with open(os.path.expanduser('~/my_asr2/datasets/resource/preprocessedData/dev-meta.txt'), encoding='utf-8') as f:
                        #         metadata = [line.strip().split('|') for line in f]
                        #         random.shuffle(metadata)
                        #     eval_loss = []
                        #     batch_size = args.hp.batch_size
                        #     batchs = len(metadata)//batch_size
                        #     for i in range(batchs):
                        #         batch = metadata[i*batch_size : i*batch_size+batch_size]
                        #         batch = list(map(eval_get_example,batch))
                        #         batch = eval_prepare_batch(batch)
                        #         feed_dict = {'labels':batch[0],}
                        #     label,final_pred_label ,log_probabilities = sess.run([
                        #         model.labels[0], model.decoded[0], model.log_probabilities[0]])
                        #     # Without the wrapping [] this raised an error at first; see https://github.com/tensorflow/tensorflow/issues/11840
                        #     print('label:            ' ,label)
                        #     print('final_pred_label: ', final_pred_label[0])
                        #     log('eval time: %.03f, avg_eval_loss: %.05f' % (time.time()-eval_start_time,np.mean(eval_loss)))

                        label,final_pred_label ,log_probabilities,y_pred2 = sess.run([
                            model.labels, model.decoded, model.log_probabilities,model.y_pred2],feed_dict=feed_dict)
                        # Without the wrapping [] this raised an error at first; see https://github.com/tensorflow/tensorflow/issues/11840
                        print('label.shape           :',label.shape) # (batch_size , label_length)
                        print('final_pred_label.shape:',np.asarray(final_pred_label).shape) # (1, batch_size, decode_length<=label_length)
                        print('y_pred2.shape         : ', y_pred2.shape)
                        print('label:            ' ,label[0])
                        print('y_pred2         : ', y_pred2[0])
                        print('final_pred_label: ', np.asarray(final_pred_label)[0][0])

                        #  At first this could not be printed: tf.nn.ctc_beam_decoder returns a sparse tensor, which cannot be printed directly.
                        #  Switching to the Keras-wrapped decoder, which converts the sparse tensor to a dense one, made it printable.

                        # waveform = audio.inv_spectrogram(spectrogram.T)
                        # audio.save_wav(waveform, os.path.join(logdir, 'step-%d-audio.wav' % step))
                        # plot.plot_alignment(alignment, os.path.join(logdir, 'step-%d-align.png' % step),
                        #                     info='%s, %s, %s, step=%d, loss=%.5f' % (
                        #                     args.model, commit, time_string(), step, loss))
                        # log('Input: %s' % sequence_to_text(input_seq))

                    # TODO: Check stop
                    if step % hp.steps_per_epoch == 0:
                        # TODO: Set up serving builder and signature map
                        serve_dir = args.serving_dir + '_' + str(total_step // hp.steps_per_epoch - 1)
                        if os.path.exists(serve_dir):
                            shutil.rmtree(serve_dir)  # the export dir must not exist; os.removedirs would fail on a non-empty directory (requires `import shutil`)
                        builder = tf.saved_model.builder.SavedModelBuilder(export_dir=serve_dir)
                        input_spec = tf.saved_model.utils.build_tensor_info(model.inputs)
                        input_len = tf.saved_model.utils.build_tensor_info(model.input_lengths)
                        output_labels = tf.saved_model.utils.build_tensor_info(model.decoded2)
                        output_logits = tf.saved_model.utils.build_tensor_info(model.pred_logits)
                        prediction_signature = (
                            tf.saved_model.signature_def_utils.build_signature_def(
                                inputs={'spec': input_spec, 'len': input_len},
                                outputs={'label': output_labels, 'logits': output_logits},
                                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
                        )
                        if first_serving:
                            first_serving = False
                            builder.add_meta_graph_and_variables(
                                sess=sess, tags=[tf.saved_model.tag_constants.SERVING, 'ASR'],
                                signature_def_map={
                                    'predict_AudioSpec2Pinyin':
                                        prediction_signature,
                                },
                                main_op=tf.tables_initializer(),
                                strip_default_attrs=True
                            )
                        else:
                            builder.add_meta_graph_and_variables(
                                sess=sess, tags=[tf.saved_model.tag_constants.SERVING, 'ASR'],
                                signature_def_map={
                                    'predict_AudioSpec2Pinyin':
                                        prediction_signature,
                                },
                                strip_default_attrs=True
                            )

                        builder.save()
                        log('Done store serving-model')

                    # TODO: Validation
                    if step % hp.steps_per_epoch ==0 and i >= 10:
                        log('validation...')
                        valid_start = time.time()
                        # TODO: validation
                        valid_hp = copy.deepcopy(hp)  # copy (requires `import copy`) so the dev-set overrides below do not clobber the training hparams
                        valid_hp.data_type = 'dev'
                        valid_hp.thchs30 = True
                        valid_hp.aishell = False
                        valid_hp.prime = False
                        valid_hp.stcmd = False
                        valid_hp.shuffle = True
                        valid_hp.data_length = None

                        valid_feeder = DataFeeder(args=valid_hp)
                        valid_batch_data = valid_feeder.get_am_batch()
                        log('valid_num_wavs:' + str(len(valid_feeder.wav_lst))) # 15219
                        valid_hp.input_vocab_size = len(valid_feeder.pny_vocab)
                        valid_hp.final_output_dim = len(valid_feeder.pny_vocab)
                        valid_hp.steps_per_epoch = len(valid_feeder.wav_lst) // valid_hp.batch_size
                        log('valid_steps_per_epoch:' + str(valid_hp.steps_per_epoch)) # 951
                        log('valid_pinyin_vocab_size:' + str(valid_hp.input_vocab_size)) # 1124
                        valid_hp.label_vocab_size = len(valid_feeder.han_vocab)
                        log('valid_label_vocab_size :' + str(valid_hp.label_vocab_size)) # 3327
                        words_num = 0
                        word_error_num = 0

                        # one pass over the dev set is enough
                        with tf.variable_scope('validation') as scope:
                            for k in range(len(valid_feeder.wav_lst) // valid_hp.batch_size):
                                valid_input_batch = next(valid_batch_data)
                                valid_feed_dict = {model.inputs: valid_input_batch['the_inputs'],
                                             model.labels: valid_input_batch['the_labels'],
                                             model.input_lengths: valid_input_batch['input_length'],
                                             model.label_lengths: valid_input_batch['label_length']}
                                # TODO: Run one step
                                valid_start_time = time.time()
                                valid_batch_loss,valid_WER = sess.run([model.batch_loss,model.WER], feed_dict=valid_feed_dict)
                                valid_time_window.append(time.time() - valid_start_time)
                                valid_loss_window.append(valid_batch_loss)
                                valid_wer_window.append(valid_WER)
                                # print('loss',loss,'batch_loss',batch_loss)
                                message = 'Valid-Step %-7d [%.03f sec/step, valid_loss=%.05f, avg_loss=%.05f, WER=%.05f, avg_WER=%.05f, lr=%.07f]' % (
                                    valid_step, valid_time_window.average, valid_batch_loss, valid_loss_window.average,valid_WER,valid_wer_window.average,K.get_value(model.learning_rate))
                                log(message)
                                summary_writer.add_summary(sess.run(valid_stats,feed_dict=valid_feed_dict), valid_step)
                                valid_step += 1
                            log('Done validation! Total time cost (sec): ' + str(time.time() - valid_start))

        except Exception as e:
            log('Exiting due to exception: %s' % e)
            traceback.print_exc()
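
All of these training loops report running averages through a ValueWindow helper that is never shown. A minimal sketch, assuming the common Tacotron-style implementation (a bounded list plus an average property):

class ValueWindow:
    def __init__(self, window_size=100):
        # Keep only the most recent `window_size` values.
        self._window_size = window_size
        self._values = []

    def append(self, x):
        # Drop the oldest entries once the window is full.
        self._values = self._values[-(self._window_size - 1):] + [x]

    @property
    def average(self):
        return sum(self._values) / max(len(self._values), 1)
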
Ejemplo n.º 9
0
def train(log_dir, args, hparams, use_hvd=False):
    if use_hvd:
        import horovod.tensorflow as hvd
        # Initialize Horovod.
        hvd.init()
    else:
        hvd = None
    eval_dir, eval_plot_dir, eval_wav_dir, meta_folder, plot_dir, save_dir, tensorboard_dir, wav_dir = init_dir(
        log_dir)

    checkpoint_path = os.path.join(save_dir, 'centaur_model.ckpt')
    input_path = os.path.join(args.base_dir, args.input_dir)

    log('Checkpoint path: {}'.format(checkpoint_path))
    log('Loading training data from: {}'.format(input_path))
    log('Using model: {}'.format(args.model))
    log(hparams_debug_string())

    # Start by setting a seed for repeatability
    tf.set_random_seed(hparams.random_seed)

    # Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder'):
        feeder = Feeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    model, stats = model_train_mode(feeder, hparams, global_step, hvd=hvd)
    eval_model = model_test_mode(feeder, hparams)

    # Embeddings metadata
    char_embedding_meta = os.path.join(meta_folder, 'CharacterEmbeddings.tsv')
    if not os.path.isfile(char_embedding_meta):
        with open(char_embedding_meta, 'w', encoding='utf-8') as f:
            for symbol in symbols:
                if symbol == ' ':
                    symbol = '\\s'  # For visual purposes, swap space with \s

                f.write('{}\n'.format(symbol))

    char_embedding_meta = char_embedding_meta.replace(log_dir, '..')
    # Book keeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=2)

    log('Centaur training set to a maximum of {} steps'.format(
        args.train_steps))

    # Memory allocation on the GPU as needed
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    if use_hvd:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    # Train
    with tf.Session(config=config) as sess:
        try:

            # Init model and load weights
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())

            # saved model restoring
            if args.restore:
                # Restore saved model if the user requested it, default = True
                restore_model(saver, sess, global_step, save_dir,
                              checkpoint_path, args.reset_global_step)
            else:
                log('Starting new training!', slack=True)
                saver.save(sess, checkpoint_path, global_step=global_step)

            # initializing feeder
            start_step = sess.run(global_step)
            feeder.start_threads(sess, start_step=start_step)
            # Horovod bcast vars across workers
            if use_hvd:
                # Horovod: broadcast initial variable states from rank 0 to all other processes.
                # This is necessary to ensure consistent initialization of all workers when
                # training is started with random weights or restored from a checkpoint.
                bcast = hvd.broadcast_global_variables(0)
                bcast.run()
                log('Worker{}: Initialized'.format(hvd.rank()))
            # Training loop
            summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)
            while not coord.should_stop() and step < args.train_steps:
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.train_op])
                main_process = hvd.rank() == 0 if use_hvd else True  # without Horovod there is a single (main) process; the original left main_process undefined in that case
                if main_process:
                    time_window.append(time.time() - start_time)
                    loss_window.append(loss)
                    message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
                        step, time_window.average, loss, loss_window.average)
                    log(message,
                        end='\r',
                        slack=(step % args.checkpoint_interval == 0))

                    if np.isnan(loss) or loss > 100.:
                        log('Loss exploded to {:.5f} at step {}'.format(
                            loss, step))
                        raise Exception('Loss exploded')

                    if step % args.summary_interval == 0:
                        log('\nWriting summary at step {}'.format(step))
                        summary_writer.add_summary(sess.run(stats), step)

                    if step % args.eval_interval == 0:
                        run_eval(args, eval_dir, eval_model, eval_plot_dir,
                                 eval_wav_dir, feeder, hparams, sess, step,
                                 summary_writer)

                    if step % args.checkpoint_interval == 0 or step == args.train_steps or step == 300:
                        save_current_model(args, checkpoint_path, global_step,
                                           hparams, loss, model, plot_dir,
                                           saver, sess, step, wav_dir)

                    if step % args.embedding_interval == 0 or step == args.train_steps or step == 1:
                        update_character_embedding(char_embedding_meta,
                                                   save_dir, summary_writer)

            log('Centaur training complete after {} global steps!'.format(
                args.train_steps),
                slack=True)
            return save_dir

        except Exception as e:
            log('Exiting due to exception: {}'.format(e), slack=True)
            traceback.print_exc()
            coord.request_stop(e)
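
Example 9 hides the Horovod specifics inside model_train_mode. The usual TF1 wiring is to scale the learning rate, wrap the optimizer, and pin one GPU per process; a sketch under those assumptions (not the repo's actual code):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

# Average gradients across workers and scale the learning rate by the world size.
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3 * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer)
# train_op = optimizer.minimize(model_loss, global_step=global_step)  # model_loss/global_step come from the model

# Pin each process to a single visible GPU, as the session config above does.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())

# Broadcast initial variables from rank 0 so all workers start identically.
bcast_op = hvd.broadcast_global_variables(0)
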
Ejemplo n.º 10
0
def train(log_dir, args):
    checkpoint_path = os.path.join(hdfs_ckpts, log_dir, 'model.ckpt')
    log(hp.to_string(), is_print=False)
    log('Loading training data from: %s' % args.tfr_dir)
    log('Checkpoint path: %s' % checkpoint_path)
    log('Using model: sygst tacotron2')

    tf_dset = TFDataSet(hp, args.tfr_dir)
    feats = tf_dset.get_train_next()
    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    training = tf.placeholder_with_default(True, shape=(), name='training')
    with tf.name_scope('model'):
        model = Tacotron2SYGST(hp)
        model(feats['inputs'],
              mel_inputs=feats['mel_targets'],
              spec_inputs=feats['linear_targets'],
              spec_lengths=feats['spec_lengths'],
              ref_inputs=feats['mel_targets'],
              ref_lengths=feats['spec_lengths'],
              arousal_labels=feats['soft_arousal_labels'],
              valence_labels=feats['soft_valance_labels'],
              training=training)
        """
        text_x, mel_x, spec_x, spec_len, aro, val = debug_data(2, 5, 10)
        model(text_x, mel_x, spec_x, spec_len, mel_x, spec_len, aro, val, training=training)
        """
        model.add_loss()
        model.add_optimizer(global_step)
        stats = model.add_stats()

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=50, keep_checkpoint_every_n_hours=2)

    # Train!
    config = tf.ConfigProto(allow_soft_placement=True,
                            gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())
            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%s' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s' % restore_path, slack=True)
            else:
                log('Starting a new training run ...', slack=True)
            """
            fetches = [global_step, model.optimize, model.loss, model.mel_loss, model.spec_loss,
                       model.stop_loss, model.arousal_loss, model.valence_loss, model.mel_grad_norms_max,
                       model.spec_grad_norms_max, model.stop_grad_norms_max, model.aro_grad_norms_max, model.val_grad_norms_max]
            """
            fetches = [
                global_step, model.optimize, model.loss, model.mel_loss,
                model.spec_loss, model.stop_loss, model.arousal_loss,
                model.valence_loss
            ]
            for _ in range(_max_step):
                start_time = time.time()
                sess.run(debug.get_ops())
                # step, _, loss, mel_loss, spec_loss, stop_loss, aro_loss, val_loss, mel_g, spec_g, stop_g, aro_g, val_g = sess.run(fetches)
                step, _, loss, mel_loss, spec_loss, stop_loss, aro_loss, val_loss = sess.run(
                    fetches)
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                """
                message = 'Step %-7d [%.3f sec/step,ml=%.3f,spl=%.3f,sl=%.3f,al=%.3f,vl=%.3f,mg=%.4f,spg=%.4f,sg=%.4f,ag=%.4f,vg=%.4f]' % (
                    step, time_window.average, mel_loss, spec_loss, stop_loss, aro_loss, val_loss, mel_g, spec_g, stop_g, aro_g, val_g)
                """
                message = 'Step %-7d [%.3f sec/step,ml=%.3f,spl=%.3f,sl=%.3f,al=%.3f,vl=%.3f]' % (
                    step, time_window.average, mel_loss, spec_loss, stop_loss,
                    aro_loss, val_loss)
                log(message, slack=(step % args.checkpoint_interval == 0))

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.5f at step %d!' % (loss, step),
                        slack=True)
                    raise Exception('Loss Exploded')

                if step % args.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    try:
                        summary_writer.add_summary(sess.run(stats), step)
                    except Exception as e:
                        log(f'summary failed and ignored: {str(e)}')

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    gt_mel, gt_spec, seq, mel, spec, align = sess.run([
                        model.mel_targets[0], model.spec_targets[0],
                        model.text_targets[0], model.mel_outputs[0],
                        model.spec_outputs[0], model.alignment_outputs[0]
                    ])
                    text = sequence_to_text(seq)
                    wav = audio.inv_spectrogram(hp, spec.T)
                    wav_path = os.path.join(log_dir,
                                            'step-%d-audio.wav' % step)
                    mel_path = os.path.join(log_dir, 'step-%d-mel.png' % step)
                    spec_path = os.path.join(log_dir,
                                             'step-%d-spec.png' % step)
                    align_path = os.path.join(log_dir,
                                              'step-%d-align.png' % step)
                    info = '%s, %s, step=%d, loss=%.5f\n %s' % (
                        args.model, time_string(), step, loss, text)
                    plot.plot_alignment(align, align_path, info=info)
                    plot.plot_mel(mel, mel_path, info=info, gt_mel=gt_mel)
                    plot.plot_mel(spec, spec_path, info=info, gt_mel=gt_spec)
                    audio.save_wav(hp, wav, wav_path)
                    log('Input: %s' % text)

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
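
Example 10 pulls its feature dict from a TFDataSet wrapper whose code is not shown. A rough tf.data sketch of what get_train_next() might return; the feature names and parsing specs are assumptions that depend on how the TFRecords were written, and a real pipeline would also convert the sparse (VarLen) features to padded dense tensors:

import tensorflow as tf

def get_train_next(tfr_pattern, batch_size=32):
    # Parse one serialized example into the fields the model above consumes.
    def _parse(example_proto):
        features = {
            'inputs': tf.VarLenFeature(tf.int64),
            'mel_targets': tf.VarLenFeature(tf.float32),
            'linear_targets': tf.VarLenFeature(tf.float32),
            'spec_lengths': tf.FixedLenFeature([], tf.int64),
        }
        return tf.parse_single_example(example_proto, features)

    files = tf.data.Dataset.list_files(tfr_pattern)
    dataset = (tf.data.TFRecordDataset(files)
               .map(_parse, num_parallel_calls=4)
               .shuffle(1024)
               .batch(batch_size)
               .repeat()
               .prefetch(1))
    return dataset.make_one_shot_iterator().get_next()
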
Ejemplo n.º 11
0
Archivo: train.py Proyecto: mutiann/ccc
def train(args):
    if args.model_path is None:
        msg = 'Prepare for new run ...'
        output_dir = os.path.join(
            args.log_dir, args.run_name + '_' +
            datetime.datetime.now().strftime('%m%d_%H%M'))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        ckpt_dir = os.path.join(
            args.ckpt_dir, args.run_name + '_' +
            datetime.datetime.now().strftime('%m%d_%H%M'))
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
    else:
        msg = 'Restart previous run ...\nlogs to save to %s, ckpt to save to %s, model to load from %s' % \
                     (args.log_dir, args.ckpt_dir, args.model_path)
        output_dir = args.log_dir
        ckpt_dir = args.ckpt_dir
        if not os.path.isdir(output_dir):
            print('Invalid log dir: %s' % output_dir)
            return
        if not os.path.isdir(ckpt_dir):
            print('Invalid ckpt dir: %s' % ckpt_dir)
            return

    set_logger(os.path.join(output_dir, 'outputs.log'))
    logging.info(msg)

    global device
    if args.device is not None:
        logging.info('Setting device to ' + args.device)
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logging.info('Setting up...')
    hparams.parse(args.hparams)
    logging.info(hparams_debug_string())

    model = EdgeClassification

    if hparams.use_roberta:
        logging.info('Using Roberta...')
        model = RobertaEdgeClassification

    global_step = 0

    if args.model_path is None:
        if hparams.load_pretrained:
            logging.info('Load online pretrained model...' + (
                ('cached at ' +
                 args.cache_path) if args.cache_path is not None else ''))
            if hparams.use_roberta:
                model = model.from_pretrained('roberta-base',
                                              cache_dir=args.cache_path,
                                              hparams=hparams)
            else:
                model = model.from_pretrained('bert-base-uncased',
                                              cache_dir=args.cache_path,
                                              hparams=hparams)
        else:
            logging.info('Build model from scratch...')
            # Build the config that matches the chosen model class (the branches were swapped in the original).
            if hparams.use_roberta:
                config = RobertaConfig.from_pretrained('roberta-base')
            else:
                config = BertConfig.from_pretrained('bert-base-uncased')
            model = model(config=config, hparams=hparams)
    else:
        if not os.path.isdir(args.model_path):
            raise OSError(str(args.model_path) + ' not found')
        logging.info('Load saved model from %s ...' % (args.model_path))
        model = model.from_pretrained(args.model_path, hparams=hparams)
        step = args.model_path.split('_')[-1]
        if step.isnumeric():
            global_step = int(step)
            logging.info('Initial step=%d' % global_step)

    if hparams.use_roberta:
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    else:
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    hparams.parse(args.hparams)
    logging.info(hparams_debug_string())

    if hparams.text_sample_eval:
        if args.eval_text_path is None:
            raise ValueError('eval_text_path not given')
        if ':' not in args.eval_text_path:
            eval_data_paths = [args.eval_text_path]
        else:
            eval_data_paths = args.eval_text_path.split(':')
        eval_feeder = []
        for p in eval_data_paths:
            name = os.path.split(p)[-1]
            if name.endswith('.tsv'):
                name = name[:-4]
            eval_feeder.append(
                (name, ExternalTextFeeder(p, hparams, tokenizer, 'dev')))
    else:
        eval_feeder = [('', DataFeeder(args.data_dir, hparams, tokenizer,
                                       'dev'))]

    tb_writer = SummaryWriter(output_dir)

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        hparams.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=hparams.learning_rate,
                      eps=hparams.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=hparams.warmup_steps,
        lr_decay_step=hparams.lr_decay_step,
        max_lr_decay_rate=hparams.max_lr_decay_rate)

    acc_step = global_step * hparams.gradient_accumulation_steps
    time_window = ValueWindow()
    loss_window = ValueWindow()
    acc_window = ValueWindow()
    model.to(device)
    model.zero_grad()
    tr_loss = tr_acc = 0.0
    start_time = time.time()

    if args.model_path is not None:
        logging.info('Load saved model from %s ...' % (args.model_path))
        if os.path.exists(os.path.join(args.model_path, 'optimizer.pt')) \
                and os.path.exists(os.path.join(args.model_path, 'scheduler.pt')):
            optimizer.load_state_dict(
                torch.load(os.path.join(args.model_path, 'optimizer.pt')))
            scheduler.load_state_dict(
                torch.load(os.path.join(args.model_path, 'scheduler.pt')))
        else:
            logging.warning('Could not find saved optimizer/scheduler')

    if global_step > 0:
        logs = run_eval(args, model, eval_feeder)
        for key, value in logs.items():
            tb_writer.add_scalar(key, value, global_step)

    logging.info('Start training...')
    if hparams.text_sample_train:
        train_feeder = PrebuiltTrainFeeder(args.train_text_path, hparams,
                                           tokenizer, 'train')
    else:
        train_feeder = DataFeeder(args.data_dir, hparams, tokenizer, 'train')

    while True:
        batch = train_feeder.next_batch()
        model.train()

        outputs = model(input_ids=batch.input_ids.to(device),
                        attention_mask=batch.input_mask.to(device),
                        token_type_ids=None if batch.token_type_ids is None
                        else batch.token_type_ids.to(device),
                        labels=batch.labels.to(device))
        loss = outputs['loss']
        preds = outputs['preds']

        acc = torch.mean((preds.cpu() == batch.labels).float())
        preds = preds.cpu().detach().numpy()
        labels = batch.labels.detach().numpy()
        t_acc = np.sum(np.logical_and(preds == 1, labels
                                      == 1)) / np.sum(labels == 1)
        f_acc = np.sum(np.logical_and(preds == 0, labels
                                      == 0)) / np.sum(labels == 0)

        if hparams.gradient_accumulation_steps > 1:
            loss = loss / hparams.gradient_accumulation_steps
            acc = acc / hparams.gradient_accumulation_steps

        tr_loss += loss.item()
        tr_acc += acc.item()
        loss.backward()
        acc_step += 1

        if acc_step % hparams.gradient_accumulation_steps != 0:
            continue

        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       hparams.max_grad_norm)
        optimizer.step()
        scheduler.step(None)
        model.zero_grad()
        global_step += 1

        step_time = time.time() - start_time
        time_window.append(step_time)
        loss_window.append(tr_loss)
        acc_window.append(tr_acc)

        if global_step % args.save_steps == 0:
            # Save model checkpoint
            model_to_save = model.module if hasattr(model, 'module') else model
            cur_ckpt_dir = os.path.join(ckpt_dir,
                                        'checkpoint_%d' % (global_step))
            if not os.path.exists(cur_ckpt_dir):
                os.makedirs(cur_ckpt_dir)
            model_to_save.save_pretrained(cur_ckpt_dir)
            torch.save(args, os.path.join(cur_ckpt_dir, 'training_args.bin'))
            torch.save(optimizer.state_dict(),
                       os.path.join(cur_ckpt_dir, 'optimizer.pt'))
            torch.save(scheduler.state_dict(),
                       os.path.join(cur_ckpt_dir, 'scheduler.pt'))
            logging.info("Saving model checkpoint to %s", cur_ckpt_dir)

        if global_step % args.logging_steps == 0:
            logs = run_eval(args, model, eval_feeder)

            learning_rate_scalar = scheduler.get_lr()[0]
            logs['learning_rate'] = learning_rate_scalar
            logs['loss'] = loss_window.average
            logs['acc'] = acc_window.average

            for key, value in logs.items():
                tb_writer.add_scalar(key, value, global_step)

        message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, acc=%.05f, avg_acc=%.05f, t_acc=%.05f, f_acc=%.05f]' % (
            global_step, step_time, tr_loss, loss_window.average, tr_acc,
            acc_window.average, t_acc, f_acc)
        logging.info(message)
        tr_loss = tr_acc = 0.0
        start_time = time.time()
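
The run above pairs AdamW with a warmup schedule through a repo-specific get_linear_schedule_with_warmup signature. With the stock transformers helper the same pattern looks like this; a sketch with a placeholder model and made-up step counts:

import torch
from transformers import AdamW, get_linear_schedule_with_warmup

model = torch.nn.Linear(16, 2)  # placeholder model for illustration only
optimizer = AdamW(model.parameters(), lr=3e-5, eps=1e-8)
num_training_steps = 10_000
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=500, num_training_steps=num_training_steps)

# One optimizer.step() followed by one scheduler.step() per update, as in the loop above.
loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()
scheduler.step()
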
Ejemplo n.º 12
0
def train(log_dir, args):
    save_dir = os.path.join(log_dir, 'pretrained/')
    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    log('Checkpoint path: {}'.format(checkpoint_path))
    log('Loading training data from: {}'.format(input_path))
    log('Using model: {}'.format(args.model))
    log(hparams_debug_string())

    #Set up feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = Feeder(coord, input_path, hparams)

    #Set up model
    step_count = 0
    try:
        #simple text file to keep count of global step
        with open(os.path.join(log_dir, 'step_counter.txt'), 'r') as file:
            step_count = int(file.read())
    except:
        print(
            'no step_counter file found, assuming there is no saved checkpoint'
        )

    global_step = tf.Variable(step_count, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs, feeder.input_lengths,
                         feeder.mel_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    #Book keeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    #Memory allocation on the GPU as needed
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    #Train
    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            #saved model restoring
            checkpoint_state = None
            if args.restore:
                #Restore saved model if the user requested it, Default = True.
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(log_dir)
                except tf.errors.OutOfRangeError as e:
                    log('Cannot restore checkpoint: {}'.format(e))

            if checkpoint_state and checkpoint_state.model_checkpoint_path:
                log('Loading checkpoint {}'.format(
                    checkpoint_state.model_checkpoint_path))
                saver.restore(sess, checkpoint_state.model_checkpoint_path)

            else:
                if not args.restore:
                    log('Starting new training!')
                else:
                    log('No model to load at {}'.format(save_dir))

            #initiating feeder
            feeder.start_in_session(sess)

            #Training loop
            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
                    step, time_window.average, loss, loss_window.average)
                log(message, end='\r')

                if loss > 100 or np.isnan(loss):
                    log('Loss exploded to {:.5f} at step {}'.format(
                        loss, step))
                    raise Exception('Loss exploded')

                if step % args.summary_interval == 0:
                    log('\nWriting summary at step: {}'.format(step))
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    with open(os.path.join(log_dir, 'step_counter.txt'),
                              'w') as file:
                        file.write(str(step))
                    log('Saving checkpoint to: {}-{}'.format(
                        checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

                    input_seq, prediction = sess.run(
                        [model.inputs[0], model.output[0]])
                    #Save an example prediction at this step
                    log('Input at step {}: {}'.format(
                        step, sequence_to_text(input_seq)))
                    log('Model prediction: {}'.format(
                        class_to_str(prediction)))

        except Exception as e:
            log('Exiting due to exception: {}'.format(e), slack=True)
            traceback.print_exc()
            coord.request_stop(e)
Ejemplo n.º 13
0
def train():
	checkpoint_path = os.path.join(Config.LogDir, 'model.ckpt')
	save_dir = os.path.join(Config.LogDir, 'pretrained/')
	input_path = Config.DataDir

	#Set up data feeder
	coord = tf.train.Coordinator()
	with tf.variable_scope('datafeeder') as scope:
		feeder = Feeder(coord, input_path)

	#Set up model:
	step_count = 0
	try:
		#simple text file to keep count of global step
		with open('step_counter.txt', 'r') as file:
			step_count = int(file.read())
	except:
		print('no step_counter file found, assuming there is no saved checkpoint')

	global_step = tf.Variable(step_count, name='global_step', trainable=False)

	model = Tacotron2.Tacotron2(global_step, feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.target_lengths)
	model.buildTacotron2()
	model.addLoss(feeder.masks)

	#Book keeping
	step = 0
	time_window = ValueWindow(100)
	loss_window = ValueWindow(100)
	saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

	#Train
	with tf.Session() as sess:
		try:
			sess.run(tf.global_variables_initializer())

			#saver.restore(sess, checkpoint_state.model_checkpoint_path)

			#initiating feeder
			feeder.start_in_session(sess)

			#Training loop
			while not coord.should_stop():
				start_time = time.time()
				step, loss, _ = sess.run([model.global_step, model.loss, model.optim])
				time_window.append(time.time() - start_time)
				loss_window.append(loss)
				if step % 1 == 0:
					message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
						step, time_window.average, loss, loss_window.average)
					print (message)
				'''
				if loss > 100 or np.isnan(loss):
					log('Loss exploded to {:.5f} at step {}'.format(loss, step))
					raise Exception('Loss exploded')

				if step % Config.CheckpointInterval == 0:
					with open('step_counter.txt', 'w') as file:
						file.write(str(step))
					log('Saving checkpoint to: {}-{}'.format(checkpoint_path, step))
					saver.save(sess, checkpoint_path, global_step=step)
					# Unlike the original tacotron, we won't save audio
					# because we yet have to use wavenet as vocoder
					log('Saving alignement..')
					input_seq, prediction, alignment = sess.run([model.inputs[0],
																 model.mel_outputs[0],
																 model.alignments[0],
																 ])
					#save predicted spectrogram to disk (for plot and manual evaluation purposes)
					mel_filename = 'ljspeech-mel-prediction-step-{}.npy'.format(step)
					np.save(os.path.join(log_dir, mel_filename), prediction.T, allow_pickle=False)

					#save alignment plot to disk (evaluation purposes)
					plot.plot_alignment(alignment, os.path.join(log_dir, 'step-{}-align.png'.format(step)),
						info='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, loss))
					log('Input at step {}: {}'.format(step, sequence_to_text(input_seq)))
				'''

		except Exception as e:
			#log('Exiting due to exception: {}'.format(e), slack=True)
			traceback.print_exc()
			coord.request_stop(e)
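
Several of these scripts log the model input through sequence_to_text, which is also defined elsewhere. A minimal sketch, assuming a hypothetical id-to-symbol table (the real table comes from the project's text/symbols module):

# Hypothetical symbol set: padding, EOS, then the characters the model was trained on.
_symbols = ['_', '~', ' '] + list("abcdefghijklmnopqrstuvwxyz'!,-.:;?")
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}

def sequence_to_text(sequence):
    # Map symbol ids back to characters, dropping padding ids.
    return ''.join(_id_to_symbol.get(int(i), '') for i in sequence).replace('_', '')
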
Ejemplo n.º 14
0
def train(log_dir, args):
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, 'training/train.txt')

    logger.log('Checkpoint path: %s' % checkpoint_path)
    logger.log('Loading training data from: %s' % input_path)

    # set up DataFeeder
    coordi = tf.train.Coordinator()
    with tf.compat.v1.variable_scope('data_feeder'):
        feeder = DataFeeder(coordi, input_path)

    # set up Model
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.compat.v1.variable_scope('model'):
        model = Tacotron()
        model.init(feeder.inputs,
                   feeder.input_lengths,
                   mel_targets=feeder.mel_targets,
                   linear_targets=feeder.linear_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    # book keeping
    step = 0
    loss_window = ValueWindow(100)
    time_window = ValueWindow(100)
    saver = tf.compat.v1.train.Saver(max_to_keep=5,
                                     keep_checkpoint_every_n_hours=2)

    # start training already!
    with tf.compat.v1.Session() as sess:
        try:
            summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)

            # initialize parameters
            sess.run(tf.compat.v1.global_variables_initializer())

            # if requested, restore from step
            if (args.restore_step):
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                logger.log('Resuming from checkpoint: %s' % restore_path)
            else:
                logger.log('Starting a new training!')

            feeder.start_in_session(sess)

            while not coordi.should_stop():
                start_time = time.time()

                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])

                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                msg = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)

                logger.log(msg)

                if loss > 100 or math.isnan(loss):
                    # bad situation
                    logger.log('Loss exploded to %.05f at step %d!' %
                               (loss, step))
                    raise Exception('Loss Exploded')

                if step % args.summary_interval == 0:
                    # it's time to write summary
                    logger.log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    # it's time to save a checkpoint
                    logger.log('Saving checkpoint to: %s-%d' %
                               (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    logger.log('Saving audio and alignment...')

                    input_seq, spectrogram, alignment = sess.run([
                        model.inputs[0], model.linear_outputs[0],
                        model.alignments[0]
                    ])

                    # convert spectrogram to waveform
                    waveform = audio.spectrogram_to_wav(spectrogram.T)
                    # save it
                    audio.save_audio(
                        waveform,
                        os.path.join(log_dir, 'step-%d-audio.wav' % step))

                    plotter.plot_alignment(
                        alignment,
                        os.path.join(log_dir, 'step-%d-align.png' % step),
                        info='%s, %s, step=%d, loss=%.5f' %
                        ('tacotron', time_string(), step, loss))

                    logger.log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            logger.log('Exiting due to exception %s' % e)
            traceback.print_exc()
            coordi.request_stop(e)
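
The add_stats helper that builds the merged TensorBoard summary is likewise external. A minimal sketch, assuming the model exposes loss (and optionally learning_rate) tensors under those names:

import tensorflow as tf

def add_stats(model, scope_name='stats'):
    # Collect a few scalar summaries and return the merged summary op.
    with tf.variable_scope(scope_name):
        tf.summary.scalar('loss', model.loss)
        if hasattr(model, 'learning_rate'):
            tf.summary.scalar('learning_rate', model.learning_rate)
        return tf.summary.merge_all()
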
Ejemplo n.º 15
0
def train(log_dir, config):
    config.data_paths = config.data_paths  # ['datasets/moon']

    data_dirs = config.data_paths  # ['datasets/moon\\data']
    num_speakers = len(data_dirs)
    config.num_test = config.num_test_per_speaker * num_speakers  # 2*1

    if num_speakers > 1 and hparams.model_type not in [
            "multi-speaker", "simple"
    ]:
        raise Exception("[!] Unknown model_type for multi-speaker: {}".format(
            hparams.model_type))

    commit = get_git_commit() if config.git else 'None'
    checkpoint_path = os.path.join(
        log_dir, 'model.ckpt'
    )  # 'logdir-tacotron\\moon_2018-08-28_13-06-42\\model.ckpt'

    #log(' [*] git recv-parse HEAD:\n%s' % get_git_revision_hash())  # hccho: commented out
    log('=' * 50)
    #log(' [*] dit diff:\n%s' % get_git_diff())
    log('=' * 50)
    log(' [*] Checkpoint path: %s' % checkpoint_path)
    log(' [*] Loading training data from: %s' % data_dirs)
    log(' [*] Using model: %s' %
        config.model_dir)  # 'logdir-tacotron\\moon_2018-08-28_13-06-42'
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        # The DataFeeder exposes 6 placeholders: train_feeder.inputs, train_feeder.input_lengths, train_feeder.loss_coeff, train_feeder.mel_targets, train_feeder.linear_targets, train_feeder.speaker_id
        train_feeder = DataFeederTacotron2(coord,
                                           data_dirs,
                                           hparams,
                                           config,
                                           32,
                                           data_type='train',
                                           batch_size=config.batch_size)
        test_feeder = DataFeederTacotron2(coord,
                                          data_dirs,
                                          hparams,
                                          config,
                                          8,
                                          data_type='test',
                                          batch_size=config.num_test)

    # Set up model:

    global_step = tf.Variable(0, name='global_step', trainable=False)

    with tf.variable_scope('model') as scope:
        model = create_model(hparams)
        model.initialize(inputs=train_feeder.inputs,
                         input_lengths=train_feeder.input_lengths,
                         num_speakers=num_speakers,
                         speaker_id=train_feeder.speaker_id,
                         mel_targets=train_feeder.mel_targets,
                         linear_targets=train_feeder.linear_targets,
                         is_training=True,
                         loss_coeff=train_feeder.loss_coeff,
                         stop_token_targets=train_feeder.stop_token_targets)

        model.add_loss()
        model.add_optimizer(global_step)
        train_stats = add_stats(model, scope_name='train')  # legacy

    with tf.variable_scope('model', reuse=True) as scope:
        test_model = create_model(hparams)
        test_model.initialize(
            inputs=test_feeder.inputs,
            input_lengths=test_feeder.input_lengths,
            num_speakers=num_speakers,
            speaker_id=test_feeder.speaker_id,
            mel_targets=test_feeder.mel_targets,
            linear_targets=test_feeder.linear_targets,
            is_training=False,
            loss_coeff=test_feeder.loss_coeff,
            stop_token_targets=test_feeder.stop_token_targets)

        test_model.add_loss()

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2)

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True

    # Train!
    #with tf.Session(config=sess_config) as sess:
    with tf.Session(config=sess_config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            if config.load_path:
                # Restore from a checkpoint if the user requested it.
                restore_path = get_most_recent_checkpoint(config.model_dir)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)
            elif config.initialize_path:
                restore_path = get_most_recent_checkpoint(
                    config.initialize_path)
                saver.restore(sess, restore_path)
                log('Initialized from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)

                zero_step_assign = tf.assign(global_step, 0)
                sess.run(zero_step_assign)

                start_step = sess.run(global_step)
                log('=' * 50)
                log(' [*] Global step is reset to {}'.format(start_step))
                log('=' * 50)
            else:
                log('Starting new training run at commit: %s' % commit,
                    slack=True)

            start_step = sess.run(global_step)

            train_feeder.start_in_session(sess, start_step)
            test_feeder.start_in_session(sess, start_step)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss_without_coeff, model.optimize])

                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % config.checkpoint_interval == 0))

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step),
                        slack=True)
                    raise Exception('Loss Exploded')

                if step % config.summary_interval == 0:
                    log('Writing summary at step: %d' % step)

                    summary_writer.add_summary(sess.run(train_stats), step)

                if step % config.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

                if step % config.test_interval == 0:
                    log('Saving audio and alignment...')
                    num_test = config.num_test

                    fetches = [
                        model.inputs[:num_test],
                        model.linear_outputs[:num_test],
                        model.alignments[:num_test],
                        test_model.inputs[:num_test],
                        test_model.linear_outputs[:num_test],
                        test_model.alignments[:num_test],
                    ]

                    sequences, spectrograms, alignments, test_sequences, test_spectrograms, test_alignments = sess.run(
                        fetches)

                    # librosa requires ffmpeg.
                    save_and_plot(
                        sequences[:1], spectrograms[:1], alignments[:1],
                        log_dir, step, loss, "train"
                    )  # spectrograms: (num_test,200,1025), alignments: (num_test,encoder_length,decoder_length)
                    save_and_plot(test_sequences, test_spectrograms,
                                  test_alignments, log_dir, step, loss, "test")

                if step == 50:
                    log("Stopping early at step %d (last avg_loss=%.05f)" %
                        (step, loss_window.average))
                    coord.request_stop()

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
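
The alignment plots saved by save_and_plot here (and by plot.plot_alignment in the other examples) boil down to a heatmap of the attention matrix. A matplotlib sketch under that assumption:

import matplotlib
matplotlib.use('Agg')  # render to files; these jobs run without a display
import matplotlib.pyplot as plt

def plot_alignment(alignment, path, info=None):
    # alignment: (encoder_steps, decoder_steps) attention weights.
    fig, ax = plt.subplots()
    im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel += '\n\n' + info
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Encoder timestep')
    plt.tight_layout()
    plt.savefig(path, format='png')
    plt.close(fig)
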
Ejemplo n.º 16
0
def train(log_dir, config):
    config.data_paths = config.data_paths  # data paths from the parsed command-line args : default='datasets/kr_example'

    data_dirs = [os.path.join(data_path, "data") \
            for data_path in config.data_paths]
    num_speakers = len(data_dirs) # number of speakers being trained: 1 for a single-speaker model, 2+ for multi-speaker
    config.num_test = config.num_test_per_speaker * num_speakers

    if num_speakers > 1 and hparams.model_type not in ["deepvoice", "simple"]:  # multi-speaker training requires the "deepvoice" or "simple" model type
        raise Exception("[!] Unknown model_type for multi-speaker: {}".format(hparams.model_type))  # the original mistakenly reported config.model_type instead of hparams.model_type

    commit = get_git_commit() if config.git else 'None'  # git-related, can be ignored
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')  # path to the model.ckpt checkpoint file

    log(' [*] git recv-parse HEAD:\n%s' % get_git_revision_hash())  # git log
    log('='*50)  # separator line
    #log(' [*] dit diff:\n%s' % get_git_diff())
    log('='*50)  # separator line
    log(' [*] Checkpoint path: %s' % checkpoint_path)  # print the checkpoint path
    log(' [*] Loading training data from: %s' % data_dirs)
    log(' [*] Using model: %s' % config.model_dir)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()  # coordinator for the feeder threads
    with tf.variable_scope('datafeeder') as scope:
        train_feeder = DataFeeder(
                coord, data_dirs, hparams, config, 32,
                data_type='train', batch_size=hparams.batch_size)
        # def __init__(self, coordinator, data_dirs, hparams, config, batches_per_group, data_type, batch_size):
        test_feeder = DataFeeder(
                coord, data_dirs, hparams, config, 8,
                data_type='test', batch_size=config.num_test)

    # Set up model:
    is_randomly_initialized = config.initialize_path is None
    global_step = tf.Variable(0, name='global_step', trainable=False)

    with tf.variable_scope('model') as scope:
        model = create_model(hparams)  # create the Tacotron model
        model.initialize(
                train_feeder.inputs, train_feeder.input_lengths,
                num_speakers,  train_feeder.speaker_id,
                train_feeder.mel_targets, train_feeder.linear_targets,
                train_feeder.loss_coeff,
                is_randomly_initialized=is_randomly_initialized)

        model.add_loss()
        model.add_optimizer(global_step)
        train_stats = add_stats(model, scope_name='stats') # legacy

    with tf.variable_scope('model', reuse=True) as scope:
        test_model = create_model(hparams)  # create the Tacotron test model
        test_model.initialize(
                test_feeder.inputs, test_feeder.input_lengths,
                num_speakers, test_feeder.speaker_id,
                test_feeder.mel_targets, test_feeder.linear_targets,
                test_feeder.loss_coeff, rnn_decoder_test_mode=True,
                is_randomly_initialized=is_randomly_initialized)
        test_model.add_loss()

    test_stats = add_stats(test_model, model, scope_name='test')  # record values such as the model loss to TensorBoard / test_model is passed as model, model as model2
    test_stats = tf.summary.merge([test_stats, train_stats])

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)  # ValueWindow with window_size = 100
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2)  # keep a checkpoint every 2 hours; checkpoints are never deleted

    sess_config = tf.ConfigProto(
            log_device_placement=False,  # if True, log which device each op is placed on
            allow_soft_placement=True)  # if allow_soft_placement is False, an error is raised when no GPU is available
    sess_config.gpu_options.allow_growth=True  # grow GPU memory usage as needed

    # Train!
    #with tf.Session(config=sess_config) as sess:
    with tf.Session() as sess:  # everything inside this with-block runs in the session (CPU or GPU)
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)  # write evaluated summary ops and the TensorBoard graph to log_dir
            sess.run(tf.global_variables_initializer())  # once the data set is loaded and the graph is fully defined, initialize the variables and start training

            if config.load_path:  # if a checkpoint path to resume from was given
                # Restore from a checkpoint if the user requested it.
                restore_path = get_most_recent_checkpoint(config.model_dir)  # path of the most recently saved checkpoint
                saver.restore(sess, restore_path)  # restore from restore_path
                log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)  # log with the git commit, also to slack
            elif config.initialize_path:  # if a checkpoint path to initialize from was given
                restore_path = get_most_recent_checkpoint(config.initialize_path)  # most recently saved checkpoint under the given path
                saver.restore(sess, restore_path)  # restore from restore_path
                log('Initialized from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)  # log with the git commit, also to slack

                zero_step_assign = tf.assign(global_step, 0)  # op that resets the global_step variable to 0
                sess.run(zero_step_assign)  # run the reset op

                start_step = sess.run(global_step)  # use the (now reset) global_step value as the starting step
                log('='*50)
                log(' [*] Global step is reset to {}'. \
                        format(start_step))  # i.e., report that the starting step was reset to 0
                log('='*50)
            else:
                log('Starting new training run at commit: %s' % commit, slack=True)  # if no previous checkpoint is used, log that a new training run starts

            start_step = sess.run(global_step)  # fetch the starting step

            train_feeder.start_in_session(sess, start_step)
            test_feeder.start_in_session(sess, start_step)

            while not coord.should_stop():  # loop until the coordinator requests a stop
                start_time = time.time()  # record the start time (seconds since the Unix epoch, UTC)
                step, loss, opt = sess.run(
                        [global_step, model.loss_without_coeff, model.optimize],
                        feed_dict=model.get_dummy_feed_dict())  # step comes from global_step; loss is the model's loss_without_coeff

                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                        step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % config.checkpoint_interval == 0))

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
                    raise Exception('Loss Exploded')

                if step % config.summary_interval == 0:
                    log('Writing summary at step: %d' % step)

                    feed_dict = {
                            **model.get_dummy_feed_dict(),
                            **test_model.get_dummy_feed_dict()
                    }
                    summary_writer.add_summary(sess.run(
                            test_stats, feed_dict=feed_dict), step)

                if step % config.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

                if step % config.test_interval == 0:
                    log('Saving audio and alignment...')
                    num_test = config.num_test

                    fetches = [
                            model.inputs[:num_test],
                            model.linear_outputs[:num_test],
                            model.alignments[:num_test],
                            test_model.inputs[:num_test],
                            test_model.linear_outputs[:num_test],
                            test_model.alignments[:num_test],
                    ]
                    feed_dict = {
                            **model.get_dummy_feed_dict(),
                            **test_model.get_dummy_feed_dict()
                    }

                    sequences, spectrograms, alignments, \
                            test_sequences, test_spectrograms, test_alignments = \
                                    sess.run(fetches, feed_dict=feed_dict)

                    save_and_plot(sequences[:1], spectrograms[:1], alignments[:1],
                            log_dir, step, loss, "train")
                    save_and_plot(test_sequences, test_spectrograms, test_alignments,
                            log_dir, step, loss, "test")

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
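Both the load_path and initialize_path branches above call get_most_recent_checkpoint(), which is defined elsewhere in that repository. A plausible minimal sketch, assuming it simply resolves the newest checkpoint prefix in a directory with the TensorFlow 1.x checkpoint utilities (the error handling here is an assumption, not the repository's actual behavior):

import tensorflow as tf

def get_most_recent_checkpoint(checkpoint_dir):
    # Illustrative sketch only: resolve the newest checkpoint prefix recorded
    # in the directory's 'checkpoint' file (e.g. '.../model.ckpt-20000').
    restore_path = tf.train.latest_checkpoint(checkpoint_dir)
    if restore_path is None:
        raise ValueError('No checkpoint found in: {}'.format(checkpoint_dir))
    return restore_path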
Ejemplo n.º 17
0
def train(log_dir, config):
    config.data_paths = config.data_paths

    data_dirs = [os.path.join(data_path, "data") \
            for data_path in config.data_paths]
    num_speakers = len(data_dirs)
    config.num_test = config.num_test_per_speaker * num_speakers

    if num_speakers > 1 and hparams.model_type not in ["deepvoice", "simple"]:
        raise Exception("[!] Unkown model_type for multi-speaker: {}".format(
            config.model_type))

    commit = get_git_commit() if config.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')

    #log(' [*] git rev-parse HEAD:\n%s' % get_git_revision_hash())
    log('=' * 50)
    #log(' [*] git diff:\n%s' % get_git_diff())
    log('=' * 50)
    log(' [*] Checkpoint path: %s' % checkpoint_path)
    log(' [*] Loading training data from: %s' % data_dirs)
    log(' [*] Using model: %s' % config.model_dir)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        train_feeder = DataFeeder(coord,
                                  data_dirs,
                                  hparams,
                                  config,
                                  32,
                                  data_type='train',
                                  batch_size=hparams.batch_size)
        test_feeder = DataFeeder(coord,
                                 data_dirs,
                                 hparams,
                                 config,
                                 8,
                                 data_type='test',
                                 batch_size=config.num_test)

    # Set up model:
    is_randomly_initialized = config.initialize_path is None
    global_step = tf.Variable(0, name='global_step', trainable=False)

    with tf.variable_scope('model') as scope:
        model = create_model(hparams)
        model.initialize(train_feeder.inputs,
                         train_feeder.input_lengths,
                         num_speakers,
                         train_feeder.speaker_id,
                         train_feeder.mel_targets,
                         train_feeder.linear_targets,
                         train_feeder.loss_coeff,
                         is_randomly_initialized=is_randomly_initialized)

        model.add_loss()
        model.add_optimizer(global_step)
        train_stats = add_stats(model, scope_name='stats')  # legacy

    with tf.variable_scope('model', reuse=True) as scope:
        test_model = create_model(hparams)
        test_model.initialize(test_feeder.inputs,
                              test_feeder.input_lengths,
                              num_speakers,
                              test_feeder.speaker_id,
                              test_feeder.mel_targets,
                              test_feeder.linear_targets,
                              test_feeder.loss_coeff,
                              rnn_decoder_test_mode=True,
                              is_randomly_initialized=is_randomly_initialized)
        test_model.add_loss()

    test_stats = add_stats(test_model, model, scope_name='test')
    test_stats = tf.summary.merge([test_stats, train_stats])

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2)

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True

    # Train!
    #with tf.Session(config=sess_config) as sess:
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            if config.load_path:
                # Restore from a checkpoint if the user requested it.
                restore_path = get_most_recent_checkpoint(config.model_dir)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)
            elif config.initialize_path:
                restore_path = get_most_recent_checkpoint(
                    config.initialize_path)
                saver.restore(sess, restore_path)
                log('Initialized from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)

                zero_step_assign = tf.assign(global_step, 0)
                sess.run(zero_step_assign)

                start_step = sess.run(global_step)
                log('=' * 50)
                log(' [*] Global step is reset to {}'. \
                        format(start_step))
                log('=' * 50)
            else:
                log('Starting new training run at commit: %s' % commit,
                    slack=True)

            start_step = sess.run(global_step)

            train_feeder.start_in_session(sess, start_step)
            test_feeder.start_in_session(sess, start_step)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss_without_coeff, model.optimize],
                    feed_dict=model.get_dummy_feed_dict())

                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % config.checkpoint_interval == 0))

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step),
                        slack=True)
                    raise Exception('Loss Exploded')

                if step % config.summary_interval == 0:
                    log('Writing summary at step: %d' % step)

                    feed_dict = {
                        **model.get_dummy_feed_dict(),
                        **test_model.get_dummy_feed_dict()
                    }
                    summary_writer.add_summary(
                        sess.run(test_stats, feed_dict=feed_dict), step)

                if step % config.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

                if step % config.test_interval == 0:
                    log('Saving audio and alignment...')
                    num_test = config.num_test

                    fetches = [
                        model.inputs[:num_test],
                        model.linear_outputs[:num_test],
                        model.alignments[:num_test],
                        test_model.inputs[:num_test],
                        test_model.linear_outputs[:num_test],
                        test_model.alignments[:num_test],
                    ]
                    feed_dict = {
                        **model.get_dummy_feed_dict(),
                        **test_model.get_dummy_feed_dict()
                    }

                    sequences, spectrograms, alignments, \
                            test_sequences, test_spectrograms, test_alignments = \
                                    sess.run(fetches, feed_dict=feed_dict)

                    save_and_plot(sequences[:1], spectrograms[:1],
                                  alignments[:1], log_dir, step, loss, "train")
                    save_and_plot(test_sequences, test_spectrograms,
                                  test_alignments, log_dir, step, loss, "test")

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
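The two examples above merge the train and test stats into a single summary protobuf and periodically write it with one FileWriter. The following self-contained TensorFlow 1.x snippet illustrates that merge-and-write pattern in isolation; the scalar names and the fake loss values are placeholders for this sketch, not the repositories' actual stats.

import tensorflow as tf

# Standalone TF1-style illustration of the summary pattern used above:
# build scalar summaries, merge them, and write them every few steps.
loss_ph = tf.placeholder(tf.float32, shape=(), name='loss')
test_loss_ph = tf.placeholder(tf.float32, shape=(), name='test_loss')
train_stats = tf.summary.scalar('train_loss', loss_ph)
test_stats = tf.summary.scalar('test_loss', test_loss_ph)
merged_stats = tf.summary.merge([train_stats, test_stats])

with tf.Session() as sess:
    writer = tf.summary.FileWriter('logdir/summaries', sess.graph)
    for step in range(1, 101):
        summary = sess.run(merged_stats,
                           feed_dict={loss_ph: 1.0 / step, test_loss_ph: 2.0 / step})
        if step % 10 == 0:
            writer.add_summary(summary, step)
    writer.close()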
Ejemplo n.º 18
0
def validate(val_loader, model, device, mels_criterion, stop_criterion, writer,
             val_dir):
    batch_time = ValueWindow()
    losses = ValueWindow()

    # switch to evaluate mode
    model.eval()

    global global_epoch
    global global_step
    with torch.no_grad():
        end = time.time()
        for i, (txts, mels, stop_tokens, txt_lengths,
                mels_lengths) in enumerate(val_loader):
            # measure data loading time
            batch_time.update(time.time() - end)

            if device > -1:
                txts = txts.cuda(device)
                mels = mels.cuda(device)
                stop_tokens = stop_tokens.cuda(device)
                txt_lengths = txt_lengths.cuda(device)
                mels_lengths = mels_lengths.cuda(device)

            # compute output
            frames, decoder_frames, stop_tokens_predict, alignment = model(
                txts, txt_lengths, mels)
            decoder_frames_loss = mels_criterion(decoder_frames,
                                                 mels,
                                                 lengths=mels_lengths)
            frames_loss = mels_criterion(frames, mels, lengths=mels_lengths)
            stop_token_loss = stop_criterion(stop_tokens_predict,
                                             stop_tokens,
                                             lengths=mels_lengths)
            loss = decoder_frames_loss + frames_loss + stop_token_loss

            losses.update(loss.item())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % hparams.print_freq == 0:
                log('Epoch: [{0}]\t'
                    'Test: [{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                        global_epoch,
                        i,
                        len(val_loader),
                        batch_time=batch_time,
                        loss=losses))
            # Logs
            writer.add_scalar("loss", float(loss.item()), global_step)
            writer.add_scalar(
                "avg_loss in {} window".format(losses.get_dinwow_size),
                float(losses.avg), global_step)
            writer.add_scalar("stop_token_loss", float(stop_token_loss.item()),
                              global_step)
            writer.add_scalar("decoder_frames_loss",
                              float(decoder_frames_loss.item()), global_step)
            writer.add_scalar("output_frames_loss", float(frames_loss.item()),
                              global_step)

        dst_alignment_path = join(val_dir,
                                  "{}_alignment.png".format(global_step))
        alignment = alignment.cpu().detach().numpy()
        plot_alignment(alignment[0, :txt_lengths[0], :mels_lengths[0]],
                       dst_alignment_path,
                       info="{}, {}".format(hparams.builder, global_step))

    return losses.avg
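The validation loop above calls mels_criterion(..., lengths=...) and stop_criterion(..., lengths=...), which are passed in by the caller and not shown here. Below is a minimal sketch of a length-masked L1 criterion with that calling convention; it is an assumption about what such a loss could look like, not the repository's implementation.

import torch
import torch.nn as nn

class MaskedL1Loss(nn.Module):
    # Illustrative sketch of a loss that ignores padded frames, matching the
    # criterion(outputs, targets, lengths=...) call pattern used above.
    # Assumes outputs/targets of shape (batch, max_time, num_mels).
    def forward(self, outputs, targets, lengths=None):
        if lengths is None:
            return nn.functional.l1_loss(outputs, targets)
        # Build a (batch, max_time) mask from per-utterance lengths.
        max_len = targets.size(1)
        ids = torch.arange(max_len, device=targets.device).unsqueeze(0)  # (1, T)
        mask = (ids < lengths.unsqueeze(1)).float()                      # (B, T)
        mask = mask.unsqueeze(-1).expand_as(targets)                     # (B, T, D)
        loss = torch.abs(outputs - targets) * mask
        return loss.sum() / mask.sum()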
Ejemplo n.º 19
0
def train(log_dir, args):
	save_dir = os.path.join(log_dir, 'pretrained/')
	checkpoint_path = os.path.join(save_dir, 'model.ckpt')
	input_path = os.path.join(args.base_dir, args.input)
	plot_dir = os.path.join(log_dir, 'plots')
	os.makedirs(plot_dir, exist_ok=True)
	log('Checkpoint path: {}'.format(checkpoint_path))
	log('Loading training data from: {}'.format(input_path))
	log('Using model: {}'.format(args.model))
	log(hparams_debug_string())

	#Set up data feeder
	coord = tf.train.Coordinator()
	with tf.variable_scope('datafeeder') as scope:
		feeder = Feeder(coord, input_path, hparams)

	#Set up model:
	step_count = 0
	try:
		#simple text file to keep count of global step
		with open(os.path.join(log_dir, 'step_counter.txt'), 'r') as file:
			step_count = int(file.read())
	except:
		print('no step_counter file found, assuming there is no saved checkpoint')

	global_step = tf.Variable(step_count, name='global_step', trainable=False)
	with tf.variable_scope('model') as scope:
		model = create_model(args.model, hparams)
		model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.token_targets)
		model.add_loss()
		model.add_optimizer(global_step)
		stats = add_stats(model)

	#Book keeping
	step = 0
	time_window = ValueWindow(100)
	loss_window = ValueWindow(100)
	saver = tf.train.Saver(max_to_keep=5)

	#Memory allocation on the GPU as needed
	config = tf.ConfigProto()
	config.gpu_options.allow_growth = True

	#Train
	with tf.Session(config=config) as sess:
		try:
			summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
			sess.run(tf.global_variables_initializer())

			#saved model restoring
			checkpoint_state = None
			if args.restore:
				#Restore the saved model if the user requested it. Default = True.
				try:
					checkpoint_state = tf.train.get_checkpoint_state(save_dir)
				except tf.errors.OutOfRangeError as e:
					log('Cannot restore checkpoint: {}'.format(e))

			if checkpoint_state and checkpoint_state.model_checkpoint_path:
				log('Loading checkpoint {}'.format(checkpoint_state.model_checkpoint_path))
				saver.restore(sess, checkpoint_state.model_checkpoint_path)

			else:
				if not args.restore:
					log('Starting new training!')
				else:
					log('No model to load at {}'.format(save_dir))

			#initiating feeder
			feeder.start_in_session(sess)

			#Training loop
			while not coord.should_stop():
				start_time = time.time()
				step, loss, opt = sess.run([global_step, model.loss, model.optimize])
				time_window.append(time.time() - start_time)
				loss_window.append(loss)
				message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
					step, time_window.average, loss, loss_window.average)
				log(message, end='\r')

				if loss > 100 or np.isnan(loss):
					log('Loss exploded to {:.5f} at step {}'.format(loss, step))
					raise Exception('Loss exploded')

				if step % args.summary_interval == 0:
					log('\nWriting summary at step: {}'.format(step))
					summary_writer.add_summary(sess.run(stats), step)
				
				if step % args.checkpoint_interval == 0:
					with open(os.path.join(log_dir,'step_counter.txt'), 'w') as file:
						file.write(str(step))
					log('Saving checkpoint to: {}-{}'.format(checkpoint_path, step))
					saver.save(sess, checkpoint_path, global_step=step)
					# Unlike the original Tacotron, we won't save audio here
					# because we have yet to plug in WaveNet as the vocoder.
					log('Saving alignment and Mel-Spectrograms..')
					input_seq, prediction, alignment, target = sess.run([model.inputs[0],
							 model.mel_outputs[0],
							 model.alignments[0],
							 model.mel_targets[0],
							 ])
					#save predicted spectrogram to disk (for plot and manual evaluation purposes)
					mel_filename = 'ljspeech-mel-prediction-step-{}.npy'.format(step)
					np.save(os.path.join(log_dir, mel_filename), prediction, allow_pickle=False)

					#save alignment plot to disk (control purposes)
					plot.plot_alignment(alignment, os.path.join(plot_dir, 'step-{}-align.png'.format(step)),
						info='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, loss))
					#save real mel-spectrogram plot to disk (control purposes)
					plot.plot_spectrogram(target, os.path.join(plot_dir, 'step-{}-real-mel-spectrogram.png'.format(step)),
						info='{}, {}, step={}, Real'.format(args.model, time_string(), step, loss))
					#save predicted mel-spectrogram plot to disk (control purposes)
					plot.plot_spectrogram(prediction, os.path.join(plot_dir, 'step-{}-pred-mel-spectrogram.png'.format(step)),
						info='{}, {}, step={}, loss={:.5}'.format(args.model, time_string(), step, loss))
					log('Input at step {}: {}'.format(step, sequence_to_text(input_seq)))

		except Exception as e:
			log('Exiting due to exception: {}'.format(e), slack=True)
			traceback.print_exc()
			coord.request_stop(e)
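plot.plot_alignment and plot.plot_spectrogram above come from a plotting module elsewhere in that repository (a similar plot_alignment is also used in the PyTorch validation example). The sketch below shows what an alignment-plot helper along those lines might do with matplotlib; the signature mirrors the calls above, but the body is an assumption, not the repository's actual code.

import matplotlib
matplotlib.use('Agg')  # write image files without needing a display
import matplotlib.pyplot as plt

def plot_alignment(alignment, path, info=None):
    # alignment: 2-D attention matrix (decoder steps x encoder steps), as a numpy array.
    fig, ax = plt.subplots()
    im = ax.imshow(alignment.T, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(im, ax=ax)
    ax.set_xlabel('Decoder timestep' + ('\n\n' + info if info else ''))
    ax.set_ylabel('Encoder timestep')
    plt.tight_layout()
    plt.savefig(path, format='png')
    plt.close(fig)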
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', type=str, default='config/hparams.yaml')
    parser.add_argument('-load_model', type=str, default=None)
    parser.add_argument('-model_name', type=str, default='P_S_Transformer_debug',
                        help='model name')
    # parser.add_argument('-batches_per_allreduce', type=int, default=1,
    #                     help='number of batches processed locally before '
    #                          'executing allreduce across workers; it multiplies '
    #                          'total batch size.')
    parser.add_argument('-num_wokers', type=int, default=0,
                        help='how many subprocesses to use for data loading. '
                             '0 means that the data will be loaded in the main process')
    parser.add_argument('-log', type=str, default='train.log')
    opt = parser.parse_args()

    configfile = open(opt.config)
    config = AttrDict(yaml.load(configfile,Loader=yaml.FullLoader))

    log_name = opt.model_name or config.model.name
    log_folder = os.path.join(os.getcwd(),'logdir/logging',log_name)
    if not os.path.isdir(log_folder):
        os.makedirs(log_folder)
    logger = init_logger(log_folder+'/'+opt.log)

    # TODO: build dataloader
    train_datafeeder = DataFeeder(config,'debug')

    # TODO: build model or load pre-trained model
    global global_step
    global_step = 0
    learning_rate = CustomSchedule(config.model.d_model)
    # learning_rate = 0.00002
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=config.optimizer.beta1, beta_2=config.optimizer.beta2,
                                         epsilon=config.optimizer.epsilon)
    logger.info('config.optimizer.beta1:' + str(config.optimizer.beta1))
    logger.info('config.optimizer.beta2:' + str(config.optimizer.beta2))
    logger.info('config.optimizer.epsilon:' + str(config.optimizer.epsilon))
    # print(str(config))
    model = Speech_transformer(config=config,logger=logger)

    #Create the checkpoint path and the checkpoint manager. This will be used to save checkpoints every n epochs.
    checkpoint_path = log_folder
    ckpt = tf.train.Checkpoint(transformer=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        logger.info('Latest checkpoint restored!!')
    else:
        logger.info('Start new run')


    # define metrics and summary writer
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    # summary_writer = tf.keras.callbacks.TensorBoard(log_dir=log_folder)
    summary_writer = summary_ops_v2.create_file_writer_v2(log_folder+'/train')


    # @tf.function
    def train_step(batch_data):
        inp = batch_data['the_inputs'] # batch*time*feature
        tar = batch_data['the_labels'] # batch*time
        # inp_len = batch_data['input_length']
        # tar_len = batch_data['label_length']
        gtruth = batch_data['ground_truth']
        tar_inp = tar
        tar_real = gtruth
        # enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp[:,:,0], tar_inp)
        combined_mask = create_combined_mask(tar=tar_inp)
        with tf.GradientTape() as tape:
            predictions, _ = model(inp, tar_inp, True, None,
                                   combined_mask, None)
            # logger.info('config.train.label_smoothing_epsilon:' + str(config.train.label_smoothing_epsilon))
            loss = LableSmoothingLoss(tar_real, predictions,config.model.vocab_size,config.train.label_smoothing_epsilon)
        gradients = tape.gradient(loss, model.trainable_variables)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
        optimizer.apply_gradients(zip(clipped_gradients, model.trainable_variables))  # apply the clipped gradients
        train_loss(loss)
        train_accuracy(tar_real, predictions)

    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    acc_window = ValueWindow(100)
    logger.info('config.train.epoches:' + str(config.train.epoches))
    first_time = True
    for epoch in range(config.train.epoches):
        logger.info('start epoch '+ str(epoch))
        logger.info('total wavs: '+ str(len(train_datafeeder)))
        logger.info('batch size: ' + str(train_datafeeder.batch_size))
        logger.info('batch per epoch: ' + str(len(train_datafeeder)//train_datafeeder.batch_size))
        train_data = train_datafeeder.get_batch()
        start_time = time.time()
        train_loss.reset_states()
        train_accuracy.reset_states()

        for step in range(len(train_datafeeder)//train_datafeeder.batch_size):
            batch_data = next(train_data)
            step_time = time.time()
            train_step(batch_data)
            if first_time:
                model.summary()
                first_time=False
            time_window.append(time.time()-step_time)
            loss_window.append(train_loss.result())
            acc_window.append(train_accuracy.result())
            message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f, acc=%.05f, avg_acc=%.05f]' % (
                    global_step, time_window.average, train_loss.result(), loss_window.average, train_accuracy.result(),acc_window.average)
            logger.info(message)

            if global_step % 10 == 0:
                with summary_ops_v2.always_record_summaries():
                    with summary_writer.as_default():
                        summary_ops_v2.scalar('train_loss', train_loss.result(), step=global_step)
                        summary_ops_v2.scalar('train_acc', train_accuracy.result(), step=global_step)

            global_step += 1

        ckpt_save_path = ckpt_manager.save()
        logger.info('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))
        logger.info('Time taken for 1 epoch: {} secs\n'.format(time.time() - start_time))
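CustomSchedule(config.model.d_model) above is a learning-rate schedule defined elsewhere in that codebase. Speech-Transformer trainers commonly use the warmup schedule from "Attention Is All You Need"; the sketch below follows that convention, so the warmup_steps default and the exact formula are assumptions rather than values read from this repository.

import tensorflow as tf

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    # Assumed Transformer warmup schedule:
    # lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

# Usage mirroring the optimizer construction above (hyperparameters are illustrative):
# optimizer = tf.keras.optimizers.Adam(CustomSchedule(512), beta_1=0.9, beta_2=0.98, epsilon=1e-9)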