Example #1
def train(log_dir, args, hparams, use_hvd=False):
    if use_hvd:
        import horovod.tensorflow as hvd
        # Initialize Horovod.
        hvd.init()
    else:
        hvd = None
    eval_dir, eval_plot_dir, eval_wav_dir, meta_folder, plot_dir, save_dir, tensorboard_dir, wav_dir = init_dir(
        log_dir)

    checkpoint_path = os.path.join(save_dir, 'centaur_model.ckpt')
    input_path = os.path.join(args.base_dir, args.input_dir)

    log('Checkpoint path: {}'.format(checkpoint_path))
    log('Loading training data from: {}'.format(input_path))
    log('Using model: {}'.format(args.model))
    log(hparams_debug_string())

    # Start by setting a seed for repeatability
    tf.set_random_seed(hparams.random_seed)

    # Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder'):
        feeder = Feeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    model, stats = model_train_mode(feeder, hparams, global_step, hvd=hvd)
    eval_model = model_test_mode(feeder, hparams)

    # Embeddings metadata
    char_embedding_meta = os.path.join(meta_folder, 'CharacterEmbeddings.tsv')
    if not os.path.isfile(char_embedding_meta):
        with open(char_embedding_meta, 'w', encoding='utf-8') as f:
            for symbol in symbols:
                if symbol == ' ':
                    symbol = '\\s'  # For visual purposes, swap space with \s

                f.write('{}\n'.format(symbol))

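    # Make the metadata path relative to the log dir; TensorBoard's projector
    # resolves embedding metadata paths relative to its log directory.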
    char_embedding_meta = char_embedding_meta.replace(log_dir, '..')
    # Book keeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=2)

    log('Centaur training set to a maximum of {} steps'.format(
        args.train_steps))

    # Memory allocation on the GPU as needed
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    if use_hvd:
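        # Horovod: pin each worker process to a single GPU (one per local rank).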
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    # Train
    with tf.Session(config=config) as sess:
        try:

            # Init model and load weights
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            sess.run(tf.tables_initializer())

            # saved model restoring
            if args.restore:
                # Restore saved model if the user requested it, default = True
                restore_model(saver, sess, global_step, save_dir,
                              checkpoint_path, args.reset_global_step)
            else:
                log('Starting new training!', slack=True)
                saver.save(sess, checkpoint_path, global_step=global_step)

            # initializing feeder
            start_step = sess.run(global_step)
            feeder.start_threads(sess, start_step=start_step)
            # Horovod bcast vars across workers
            if use_hvd:
                # Horovod: broadcast initial variable states from rank 0 to all other processes.
                # This is necessary to ensure consistent initialization of all workers when
                # training is started with random weights or restored from a checkpoint.
                bcast = hvd.broadcast_global_variables(0)
                bcast.run()
                log('Worker{}: Initialized'.format(hvd.rank()))
            # Training loop
            summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)
            while not coord.should_stop() and step < args.train_steps:
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.train_op])
                # With Horovod, only rank 0 handles logging, eval and checkpoints;
                # without Horovod, this process is always the main one.
                main_process = hvd.rank() == 0 if use_hvd else True
                if main_process:
                    time_window.append(time.time() - start_time)
                    loss_window.append(loss)
                    message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
                        step, time_window.average, loss, loss_window.average)
                    log(message,
                        end='\r',
                        slack=(step % args.checkpoint_interval == 0))

                    if np.isnan(loss) or loss > 100.:
                        log('Loss exploded to {:.5f} at step {}'.format(
                            loss, step))
                        raise Exception('Loss exploded')

                    if step % args.summary_interval == 0:
                        log('\nWriting summary at step {}'.format(step))
                        summary_writer.add_summary(sess.run(stats), step)

                    if step % args.eval_interval == 0:
                        run_eval(args, eval_dir, eval_model, eval_plot_dir,
                                 eval_wav_dir, feeder, hparams, sess, step,
                                 summary_writer)

                    if step % args.checkpoint_interval == 0 or step == args.train_steps or step == 300:
                        save_current_model(args, checkpoint_path, global_step,
                                           hparams, loss, model, plot_dir,
                                           saver, sess, step, wav_dir)

                    if step % args.embedding_interval == 0 or step == args.train_steps or step == 1:
                        update_character_embedding(char_embedding_meta,
                                                   save_dir, summary_writer)

            log('Centaur training complete after {} global steps!'.format(
                args.train_steps),
                slack=True)
            return save_dir

        except Exception as e:
            log('Exiting due to exception: {}'.format(e), slack=True)
            traceback.print_exc()
            coord.request_stop(e)
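Below is a minimal, hypothetical launcher for the train() function above. The attribute names on args are taken from how the function reads them (base_dir, input_dir, model, restore, reset_global_step, train_steps, and the *_interval fields); the concrete values, the 'logs-centaur' directory name, and the hparams import are illustrative assumptions rather than the project's real defaults.

import argparse
import os

from hparams import hparams  # assumed: the project's hyper-parameter module

# Hypothetical invocation of train(); attribute names mirror how `args` is used
# inside the function, but the concrete values here are only placeholders.
args = argparse.Namespace(
    base_dir=os.path.expanduser('~/centaur'),
    input_dir='training_data/train.txt',
    model='Centaur',
    restore=True,               # restore from the latest checkpoint if one exists
    reset_global_step=False,
    train_steps=500000,
    summary_interval=250,
    checkpoint_interval=5000,
    eval_interval=10000,
    embedding_interval=10000)

log_dir = os.path.join(args.base_dir, 'logs-centaur')
train(log_dir, args, hparams, use_hvd=False)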
Example #2
def main(args):
    eval_fn = os.path.join(args.model_dir, 'eval-detailed.txt')
    assert os.path.exists(args.model_dir), 'Model dir does not exist.'
    assert args.overwrite or not os.path.exists(
        eval_fn), 'Evaluation file already exists.'
    os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % args.gpu

    print '\n' + '=' * 30 + ' ARGUMENTS ' + '=' * 30
    params = myutils.load_params(args.model_dir)
    for k, v in params.__dict__.iteritems():
        print 'TRAIN | {}: {}'.format(k, v)
    for k, v in args.__dict__.iteritems():
        print 'EVAL | {}: {}'.format(k, v)
    sys.stdout.flush()

    DURATION = 0.1
    BATCH_SIZE = 16
    with tf.device('/cpu:0'), tf.variable_scope('feeder'):
        feeder = Feeder(params.db_dir,
                        subset_fn=args.subset_fn,
                        ambi_order=params.ambi_order,
                        audio_rate=params.audio_rate,
                        video_rate=params.video_rate,
                        context=params.context,
                        duration=DURATION,
                        return_video=VIDEO in params.encoders,
                        img_prep=myutils.img_prep_fcn(),
                        return_flow=FLOW in params.encoders,
                        frame_size=(224, 448),
                        queue_size=BATCH_SIZE * 5,
                        n_threads=4,
                        for_eval=True)
        batches = feeder.dequeue(BATCH_SIZE)

        ambix_batch = batches['ambix']
        video_batch = batches['video'] if VIDEO in params.encoders else None
        flow_batch = batches['flow'] if FLOW in params.encoders else None
        audio_mask_batch = batches['audio_mask']

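        # ss: start of the target window (half the context, in samples); t: target
        # length (DURATION seconds at the audio rate). The first ambi_order**2
        # channels over the full window are the input; the remaining higher-order
        # channels over [ss, ss + t) are the prediction target.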
        ss = int(params.audio_rate * params.context) / 2
        t = int(params.audio_rate * DURATION)
        audio_input = ambix_batch[:, :, :params.ambi_order**2]
        audio_target = ambix_batch[:, ss:ss + t, params.ambi_order**2:]

    print '\n' + '=' * 20 + ' MODEL ' + '=' * 20
    sys.stdout.flush()
    with tf.device('/gpu:0'):
        # Model
        num_sep = params.num_sep_tracks if params.separation != NO_SEPARATION else 1
        net_params = SptAudioGenParams(
            sep_num_tracks=num_sep,
            ctx_feats_fc_units=params.context_units,
            loc_fc_units=params.loc_units,
            sep_freq_mask_fc_units=params.freq_mask_units,
            sep_fft_window=params.fft_window)
        model = SptAudioGen(ambi_order=params.ambi_order,
                            audio_rate=params.audio_rate,
                            video_rate=params.video_rate,
                            context=params.context,
                            sample_duration=DURATION,
                            encoders=params.encoders,
                            separation=params.separation,
                            params=net_params)

        # Inference
        pred_t = model.inference_ops(audio=audio_input,
                                     video=video_batch,
                                     flow=flow_batch,
                                     is_training=False)

        # Losses and evaluation metrics
        with tf.variable_scope('metrics'):
            w_t = audio_input[:, ss:ss + t]
            _, stft_dist_ps, lsd_ps, mse_ps, snr_ps = model.evaluation_ops(
                pred_t,
                audio_target,
                w_t,
                mask_channels=audio_mask_batch[:, params.ambi_order**2:])
        # Loader
        vars2save = [
            v for v in tf.global_variables()
            if not v.op.name.startswith('metrics')
        ]
        saver = tf.train.Saver(vars2save)

    print '\n' + '=' * 30 + ' VARIABLES ' + '=' * 30
    model_vars = tf.global_variables()
    import numpy as np
    for v in model_vars:
        if 'Adam' in v.op.name.split('/')[-1]:
            continue
        print ' * {:50s} | {:20s} | {:7s} | {:10s}'.format(
            v.op.name, str(v.get_shape()), str(np.prod(v.get_shape())),
            str(v.dtype))

    print '\n' + '=' * 30 + ' EVALUATION ' + '=' * 30
    sys.stdout.flush()
    config = tf.ConfigProto(allow_soft_placement=True,
                            gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=config) as sess:
        print 'Loading model...'
        sess.run(model.init_ops)
        saver.restore(sess, tf.train.latest_checkpoint(args.model_dir))

        print 'Initializing data feeders...'
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(sess, coord)
        feeder.start_threads(sess)

        all_metrics = [
            'amplitude/predicted', 'amplitude/gt', 'mse/avg', 'mse/X', 'mse/Y',
            'mse/Z', 'stft/avg', 'stft/X', 'stft/Y', 'stft/Z', 'lsd/avg',
            'lsd/X', 'lsd/Y', 'lsd/Z', 'mel_lsd/avg', 'mel_lsd/X', 'mel_lsd/Y',
            'mel_lsd/Z', 'snr/avg', 'snr/X', 'snr/Y', 'snr/Z', 'env_mse/avg',
            'env_mse/X', 'env_mse/Y', 'env_mse/Z', 'emd/dir', 'emd/dir2'
        ]
        metrics = OrderedDict([(key, []) for key in all_metrics])
        sample_ids = []
        telapsed = deque(maxlen=20)

        print 'Start evaluation...'
        it = -1
        # run_options = tf.RunOptions(timeout_in_ms=60*1000)
        while True:
            it += 1
            if feeder.done(sess):
                break
            start_time = time.time()
            outs = sess.run([
                batches['id'], audio_mask_batch, w_t, audio_target, pred_t,
                stft_dist_ps, lsd_ps, mse_ps, snr_ps
            ])
            video_id, layout, mono, gt, pred = outs[:5]
            gt_m = np.concatenate(
                (mono, gt), axis=2) * layout[:, np.newaxis, :]
            pred_m = np.concatenate(
                (mono, pred), axis=2) * layout[:, np.newaxis, :]
            stft_dist, lsd, mse, snr = outs[5:]

            _env_time = 0.
            _emd_time = 0.
            _pow_time = 0.
            _lsd_time = 0.
            for smp in range(BATCH_SIZE):
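                # Per-channel metrics: the predicted first-order channels follow
                # W in ACN order, i.e. Y, Z, X.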
                metrics['stft/avg'].append(np.mean(stft_dist[smp]))
                for i, ch in zip(range(3), 'YZX'):
                    metrics['stft/' + ch].append(stft_dist[smp, i])

                metrics['lsd/avg'].append(np.mean(lsd[smp]))
                for i, ch in zip(range(3), 'YZX'):
                    metrics['lsd/' + ch].append(lsd[smp, i])

                metrics['mse/avg'].append(np.mean(mse[smp]))
                for i, ch in zip(range(3), 'YZX'):
                    metrics['mse/' + ch].append(mse[smp, i])

                metrics['snr/avg'].append(np.nanmean(snr[smp]))
                for i, ch in zip(range(3), 'YZX'):
                    metrics['snr/' + ch].append(snr[smp, i])

                # Compute Mel LSD distance
                _t = time.time()
                mel_lsd = myutils.compute_lsd_dist(pred[smp], gt[smp],
                                                   params.audio_rate)
                metrics['mel_lsd/avg'].append(np.mean(mel_lsd))
                for i, ch in zip(range(3), 'YZX'):
                    metrics['mel_lsd/' + ch].append(mel_lsd[i])
                _lsd_time += (time.time() - _t)

                # Compute envelope distances
                _t = time.time()
                env_mse = myutils.compute_envelope_dist(pred[smp], gt[smp])
                metrics['env_mse/avg'].append(np.mean(env_mse))
                for i, ch in zip(range(3), 'YZX'):
                    metrics['env_mse/' + ch].append(env_mse[i])
                _env_time += (time.time() - _t)

                # Compute EMD (for speed, only compute emd over first 0.1s of every 1sec)
                _t = time.time()
                emd_dir, emd_dir2 = ambix_emd(pred_m[smp],
                                              gt_m[smp],
                                              model.snd_rate,
                                              ang_res=30)
                metrics['emd/dir'].append(emd_dir)
                metrics['emd/dir2'].append(emd_dir2)
                _emd_time += (time.time() - _t)

                # Compute chunk power
                _t = time.time()
                metrics['amplitude/gt'].append(np.abs(gt[smp]).max())
                metrics['amplitude/predicted'].append(np.abs(pred[smp]).max())
                _pow_time += (time.time() - _t)

                sample_ids.append(video_id[smp])

            telapsed.append(time.time() - start_time)
            #print '\nTotal:', telapsed[-1]
            #print 'Env:', _env_time
            #print 'LSD:', _lsd_time
            #print 'EMD:', _emd_time
            #print 'POW:', _pow_time

            if it % 100 == 0:
                # Store evaluation metrics
                with open(eval_fn, 'w') as f:
                    f.write('SampleID | {}\n'.format(' '.join(metrics.keys())))
                    for smp in range(len(sample_ids)):
                        f.write('{} | {}\n'.format(
                            sample_ids[smp], ' '.join(
                                [str(metrics[key][smp]) for key in metrics])))

            if it % 5 == 0:
                stats = OrderedDict([(m, np.mean(metrics[m]))
                                     for m in all_metrics])
                myutils.print_stats(stats.values(),
                                    stats.keys(),
                                    BATCH_SIZE,
                                    telapsed,
                                    it,
                                    tag='EVAL')
                sys.stdout.flush()

        # Print progress
        stats = OrderedDict([(m, np.mean(metrics[m])) for m in all_metrics])
        myutils.print_stats(stats.values(),
                            stats.keys(),
                            BATCH_SIZE,
                            telapsed,
                            it,
                            tag='EVAL')
        sys.stdout.flush()
        with open(eval_fn, 'w') as f:
            f.write('SampleID | {}\n'.format(' '.join(metrics.keys())))
            for smp in range(len(sample_ids)):
                f.write('{} | {}\n'.format(
                    sample_ids[smp],
                    ' '.join([str(metrics[key][smp]) for key in metrics])))

        print('\n' + '#' * 60)
        print('End of evaluation.')
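The loop above writes eval-detailed.txt as a plain pipe-separated table: a 'SampleID | metric1 metric2 ...' header followed by one space-separated row of values per sample. A small sketch of reading that file back and averaging each metric, assuming only the write format shown above:

import numpy as np

def load_detailed_eval(path):
    """Parse the 'SampleID | v1 v2 ...' file written by the evaluation loop above."""
    with open(path) as f:
        metric_names = f.readline().strip().split(' | ')[1].split(' ')
        rows = {}
        for line in f:
            sample_id, values = line.strip().split(' | ', 1)
            rows[sample_id] = [float(v) for v in values.split(' ')]
    return metric_names, rows

# Usage: report the mean of every metric over all evaluated samples.
metric_names, rows = load_detailed_eval('eval-detailed.txt')
values = np.array(list(rows.values()))
for name, avg in zip(metric_names, np.nanmean(values, axis=0)):
    print('{:20s} {}'.format(name, avg))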
Example #3
def main(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    print('\n' + '=' * 30 + ' ARGUMENTS ' + '=' * 30)
    sys.stdout.flush()
    if args.resume:
        params = myutils.load_params(args.model_dir)
        args.encoders = params.encoders
        args.separation = params.separation
        args.ambi_order = params.ambi_order
        args.audio_rate = params.audio_rate
        args.video_rate = params.video_rate
        args.context = params.context
        args.sample_dur = params.sample_dur
    else:
        myutils.save_params(args)
    myutils.print_params(args)

    # Feeder
    min_t = min([args.context, args.sample_dur, 1. / args.video_rate])
    args.video_rate = int(1. / min_t)
    with tf.device('/cpu:0'), tf.variable_scope('feeder'):
        feeder = Feeder(args.db_dir,
                        subset_fn=args.subset_fn,
                        ambi_order=args.ambi_order,
                        audio_rate=args.audio_rate,
                        video_rate=args.video_rate,
                        context=args.context,
                        duration=args.sample_dur,
                        return_video=VIDEO in args.encoders,
                        img_prep=myutils.img_prep_fcn(),
                        return_flow=FLOW in args.encoders,
                        frame_size=(224, 448),
                        queue_size=args.batch_size * 5,
                        n_threads=4,
                        for_eval=False)

        batches = feeder.dequeue(args.batch_size)
        ambix_batch = batches['ambix']
        video_batch = batches['video'] if 'video' in args.encoders else None
        flow_batch = batches['flow'] if 'flow' in args.encoders else None
        audio_mask_batch = batches['audio_mask']

        t = int(args.audio_rate * args.sample_dur)
        ss = int(args.audio_rate * args.context) / 2
        n_chann_in = args.ambi_order**2
        audio_input = ambix_batch[:, :, :n_chann_in]
        audio_target = ambix_batch[:, ss:ss + t, n_chann_in:]

    print('\n' + '=' * 20 + ' MODEL ' + '=' * 20)
    sys.stdout.flush()
    with tf.device('/gpu:0'):
        # Model
        num_sep = args.num_sep_tracks if args.separation != NO_SEPARATION else 1
        params = SptAudioGenParams(sep_num_tracks=num_sep,
                                   ctx_feats_fc_units=args.context_units,
                                   loc_fc_units=args.loc_units,
                                   sep_freq_mask_fc_units=args.freq_mask_units,
                                   sep_fft_window=args.fft_window)
        model = SptAudioGen(ambi_order=args.ambi_order,
                            audio_rate=args.audio_rate,
                            video_rate=args.video_rate,
                            context=args.context,
                            sample_duration=args.sample_dur,
                            encoders=args.encoders,
                            separation=args.separation,
                            params=params)
        ambix_pred = model.inference_ops(audio=audio_input,
                                         video=video_batch,
                                         flow=flow_batch,
                                         is_training=True)

        # Losses and evaluation metrics
        print(audio_mask_batch)
        with tf.variable_scope('metrics'):
            metrics_t, _, _, _, _ = model.evaluation_ops(
                ambix_pred,
                audio_target,
                audio_input[:, ss:ss + t],
                mask_channels=audio_mask_batch[:, args.ambi_order**2:])

        step_t = tf.Variable(0, trainable=False, name='step')
        with tf.variable_scope('loss'):
            loss_t = model.loss_ops(metrics_t, step_t)
            losses_t = {l: loss_t[l] for l in loss_t}
            regularizers = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            if regularizers and 'regularization' in losses_t:
                losses_t['regularization'] = tf.add_n(regularizers)
            losses_t['total_loss'] = tf.add_n(losses_t.values())

        # Optimizer
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        # Ensure UPDATE_OPS (e.g. batch-norm moving averages) run before each train step.
        with tf.variable_scope('optimization'), tf.control_dependencies(
                update_ops):
            train_op, lr_t = myutils.optimize(losses_t['total_loss'], step_t,
                                              args)

        # Initialization
        rest_ops = model.init_ops
        init_op = [
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ]
        saver = tf.train.Saver(max_to_keep=1)

        # Tensorboard
        metrics_t['training_loss'] = losses_t['total_loss']
        metrics_t['queue'] = feeder.queue_state
        metrics_t['lr'] = lr_t
        myutils.add_scalar_summaries(metrics_t.values(), metrics_t.keys())
        summary_ops = tf.summary.merge(
            tf.get_collection(tf.GraphKeys.SUMMARIES))
        summary_writer = tf.summary.FileWriter(args.model_dir, flush_secs=30)
        #summary_writer.add_graph(tf.get_default_graph())

    print('\n' + '=' * 30 + ' VARIABLES ' + '=' * 30)
    model_vars = tf.global_variables()
    import numpy as np
    for v in model_vars:
        if 'Adam' in v.op.name.split('/')[-1]:
            continue
        print(' * {:50s} | {:20s} | {:7s} | {:10s}'.format(
            v.op.name, str(v.get_shape()), str(np.prod(v.get_shape())),
            str(v.dtype)))

    print('\n' + '=' * 30 + ' TRAINING ' + '=' * 30)
    sys.stdout.flush()
    config = tf.ConfigProto(allow_soft_placement=True,
                            gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=config) as sess:
        print('Initializing network...')
        sess.run(init_op)
        if rest_ops:
            sess.run(rest_ops)

        print('Initializing data feeders...')
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(sess, coord)
        feeder.start_threads(sess)

        tf.get_default_graph().finalize()

        # Restore model
        init_step = 0
        if args.resume:
            print('Restoring previously saved model...')
            ckpt = tf.train.latest_checkpoint(args.model_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                init_step = sess.run(step_t)

        try:
            print('Start training...')
            duration = deque(maxlen=20)
            for step in range(init_step, args.n_iters):
                start_time = time.time()
                if step % 20 != 0:
                    sess.run(train_op)
                else:
                    outs = sess.run(
                        [train_op, summary_ops, losses_t['total_loss']] +
                        losses_t.values() + metrics_t.values())
                    if math.isnan(outs[2]):
                        raise ValueError(
                            'Training produced a NaN metric or loss.')
                duration.append(time.time() - start_time)

                if step % 20 == 0:  # Print progress to terminal and tensorboard
                    myutils.print_stats(outs[3:],
                                        losses_t.keys() + metrics_t.keys(),
                                        args.batch_size,
                                        duration,
                                        step,
                                        tag='TRAIN')
                    summary_writer.add_summary(outs[1], step)
                    sys.stdout.flush()

                if step % 5000 == 0 and step != 0:  # Save checkpoint
                    saver.save(sess,
                               args.model_dir + '/model.ckpt',
                               global_step=step_t)
                    print('=' * 60 + '\nCheckpoint saved\n' + '=' * 60)

        except Exception as e:
            print(str(e))

        finally:
            # Assumed cleanup: stop the coordinator so feeder/queue threads can exit.
            coord.request_stop()
Example #4
def main():

    # Instantiate Configs
    config = hparams

    # Directory Setting
    input_path = os.path.join(config.wavenet_input, 'map.txt')
    post_input_path = os.path.join(
        config.post_train_input,
        'map.txt')  #post #yk: set the input path for the post-training case
    log_dir = config.log_dir
    save_dir = os.path.join(log_dir, 'wave_pretrained')
    plot_dir = os.path.join(log_dir, 'plots')
    wav_dir = os.path.join(log_dir, 'wavs')
    eval_dir = os.path.join(log_dir, 'eval-dir')
    eval_plot_dir = os.path.join(eval_dir, 'plots')
    eval_wav_dir = os.path.join(eval_dir, 'wavs')
    tensorboard_dir = os.path.join(log_dir, 'wavenet_events')
    meta_folder = os.path.join(log_dir, 'metas')
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(wav_dir, exist_ok=True)
    os.makedirs(eval_dir, exist_ok=True)
    os.makedirs(eval_plot_dir, exist_ok=True)
    os.makedirs(eval_wav_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(meta_folder, exist_ok=True)

    checkpoint_path = os.path.join(save_dir, 'wavenet_model.ckpt')
    input_path = os.path.join(config.base_dir, input_path)
    post_input_path = os.path.join(
        config.base_dir,
        post_input_path)  #post #yk: set the input path for the post-training case
    log.info('Checkpoint_path: {}'.format(checkpoint_path))
    if config.spk_train_mode:  # spk-train mode  #post
        log.info(
            'Loading spk-training data from: {}'.format(input_path))  #post
    elif config.post_train_mode:  # post-train mode  #post
        log.info('Loading post-training data from: {}'.format(
            post_input_path))  #post
    else:  # train all vars  #post
        log.info('Loading training data from: {}'.format(input_path))  #post
    log.info('Using model: {}'.format('WaveNet'))
    log.info(hparams_debug_string())

    #Start by setting a seed for repeatability
    tf.set_random_seed(config.wavenet_random_seed)

    #Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        if config.post_train_mode:  # post-train mode  #post
            feeder = Feeder(coord, post_input_path, config.base_dir,
                            config)  #post
        else:  # train all vars  #post #yk: the feeder loads different data depending on whether this is a post-training run
            feeder = Feeder(coord, input_path, config.base_dir, config)  #post
    #Instantiate Model Class (Graphing)
    global_step = tf.Variable(0, name='global_step', trainable=False)
    model, stats = model_train_mode(feeder, config, global_step)
    eval_model = model_test_mode(feeder, config, global_step)  ##EVAL

    #Speaker Embeddings metadata
    if config.speakers_path is not None:
        speaker_embedding_meta = config.speakers_path

    else:
        speaker_embedding_meta = os.path.join(meta_folder,
                                              'SpeakerEmbeddings.tsv')
        if not os.path.isfile(speaker_embedding_meta):
            with open(speaker_embedding_meta, 'w', encoding='utf-8') as f:
                for speaker in config.speakers:
                    f.write('{}\n'.format(speaker))

        speaker_embedding_meta = speaker_embedding_meta.replace(log_dir, '..')

    #book keeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
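    # 'Shadow' saver: tracks the exponential-moving-average copies of the weights.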
    sh_saver = create_shadow_saver(model, global_step)

    if config.post_train_mode:
        log.info('Wavenet post training set to a maximum of {} steps'.format(
            config.post_train_steps))
        train_steps = config.post_train_steps
    elif config.pre_train_mode:
        log.info('Wavenet pre training set to a maximum of {} steps'.format(
            config.pre_train_steps))
        train_steps = config.pre_train_steps
    elif config.spk_train_mode:
        log.info('Wavenet spk training set to a maximum of {} steps'.format(
            config.spk_train_steps))
        train_steps = config.spk_train_steps
    else:  # train all vars  #post  #yk: read the post-train max step separately >> unified into train_steps (what was config.wavenet_train_steps in the train loop below now uses train_steps)
        log.info('Wavenet training set to a maximum of {} steps'.format(
            config.wavenet_train_steps))
        train_steps = config.wavenet_train_steps

    #Memory allocation on the GPU as needed
    conf = tf.ConfigProto()
    conf.gpu_options.allow_growth = True
    conf.allow_soft_placement = True
    run_init = False

    #Train
    with tf.Session(config=conf) as sess:
        try:
            summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            #saved model restoring
            if config.restore:
                # Restore saved model if the user requested it, default = True
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)

                    if (checkpoint_state
                            and checkpoint_state.model_checkpoint_path):
                        log.info('Loading checkpoint {}'.format(
                            checkpoint_state.model_checkpoint_path))
                        load_averaged_model(
                            sess, sh_saver,
                            checkpoint_state.model_checkpoint_path)
                    else:
                        log.info('No model to load at {}'.format(save_dir))
                        if config.wavenet_weight_normalization:
                            run_init = True

                except tf.errors.OutOfRangeError as e:
                    log.info('Cannot restore checkpoint: {}'.format(e))
            else:
                log.info('Starting new training!')
                if config.wavenet_weight_normalization:
                    run_init = True

            if run_init:
                log.info(
                    '\nApplying Weight normalization in fresh training. Applying data dependent initialization forward pass..'
                )
                #Create init_model
                init_model, _ = model_train_mode(feeder,
                                                 config,
                                                 global_step,
                                                 init=True)

            #initializing feeder
            feeder.start_threads(sess)

            if run_init:
                #Run one forward pass for model parameters initialization (make prediction on init_batch)
                _ = sess.run(init_model.tower_y_hat)
                log.info(
                    'Data dependent initialization done. Starting training!')

            #Training loop
            while not coord.should_stop() and step < train_steps:
                start_time = time.time()

                step, loss, vq_loss, vq_perplexity, reconst_loss, spk_loss, opt = sess.run(
                    [
                        global_step, model.loss, model.vq_loss,
                        model.vq_perplexity, model.reconst_loss,
                        model.spk_loss, model.optimize
                    ])

                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}, vq_loss={:.5f}, vq_perplexity={:.5f}, reconst_loss={:.5f}, spk_loss={:.5f}]'.format(
                    step, time_window.average, loss, loss_window.average,
                    vq_loss, vq_perplexity, reconst_loss, spk_loss)
                log.info(message)

                if np.isnan(loss) or loss > 10000:

                    log.info('Loss exploded to {:.5f} at step {}'.format(
                        loss, step))
                    raise Exception('Loss exploded')

                if step % config.summary_interval == 0:
                    log.info('\nWriting summary at step {}'.format(step))
                    summary_writer.add_summary(sess.run(stats), step)

                if step % config.checkpoint_interval == 0 or step == train_steps:
                    save_log(sess,
                             step,
                             model,
                             plot_dir,
                             wav_dir,
                             config=config)
                    save_checkpoint(sess, sh_saver, checkpoint_path,
                                    global_step)

                if step % config.eval_interval == 0:
                    log.info('\nEvaluating at step {}'.format(step))
                    eval_step(sess,
                              step,
                              eval_model,
                              eval_plot_dir,
                              eval_wav_dir,
                              summary_writer=summary_writer,
                              config=model._config)  ##EVAL

                if config.gin_channels > 0 and (
                        step % config.embedding_interval == 0
                        or step == train_steps):  #or step == 1):
                    #Get current checkpoint state
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)
                    print("checkpoint_state : {}".format(checkpoint_state))

                    #Update Projector
                    log.info(
                        '\nSaving Model Speaker Embeddings visualization..')
                    add_embedding_stats(summary_writer,
                                        [model.embedding_table.name],
                                        [speaker_embedding_meta],
                                        checkpoint_state.model_checkpoint_path)
                    log.info(
                        'WaveNet Speaker embeddings have been updated on tensorboard!'
                    )

            log.info('Wavenet training complete after {} global steps'.format(
                train_steps))
            return save_dir

        except Exception as e:
            log.info('Exiting due to exception: {}'.format(e))
            traceback.print_exc()
            coord.request_stop(e)