Example 1
def main():
    def _str_to_bool(s):
        """Convert string to bool (in argparse context)."""
        if s.lower() not in ['true', 'false']:
            raise ValueError(
                'Argument needs to be a boolean, got {}'.format(s))
        return {'true': True, 'false': False}[s.lower()]

    parser = argparse.ArgumentParser(description='WaveNet example network')

    DATA_DIRECTORY = './data/kss,./data/son'
    parser.add_argument('--data_dir',
                        type=str,
                        default=DATA_DIRECTORY,
                        help='The directory containing the VCTK corpus.')

    LOGDIR = None
    #LOGDIR = './/logdir-wavenet//train//2018-12-21T22-58-10'

    parser.add_argument(
        '--logdir',
        type=str,
        default=LOGDIR,
        help=
        'Directory in which to store the logging information for TensorBoard. If the model already exists, it will restore the state and will continue training. Cannot use with --logdir_root and --restore_from.'
    )

    parser.add_argument(
        '--logdir_root',
        type=str,
        default=None,
        help=
        'Root directory to place the logging output and generated model. These are stored under the dated subdirectory of --logdir_root. Cannot use with --logdir.'
    )
    parser.add_argument(
        '--restore_from',
        type=str,
        default=None,
        help=
        'Directory in which to restore the model from. This creates the new model under the dated directory in --logdir_root. Cannot use with --logdir.'
    )

    CHECKPOINT_EVERY = 1000  # checkpoint save interval
    parser.add_argument(
        '--checkpoint_every',
        type=int,
        default=CHECKPOINT_EVERY,
        help='How many steps to save each checkpoint after. Default: ' +
        str(CHECKPOINT_EVERY) + '.')

    config = parser.parse_args()  # options that can be supplied on the command line
    config.data_dir = config.data_dir.split(",")

    try:
        directories = validate_directories(config, hparams)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    log_path = os.path.join(logdir, 'train.log')
    infolog.init(log_path, logdir)

    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Create coordinator.
    coord = tf.train.Coordinator()
    num_speakers = len(config.data_dir)
    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = hparams.silence_threshold if hparams.silence_threshold > EPSILON else None
        gc_enable = num_speakers > 1

        # The AudioReader slices the wav files into inputs: the first receptive_field samples are padded or taken from the previous chunk, and each chunk has size (receptive_field + sample_size).
        reader = DataFeederWavenet(
            coord,
            config.data_dir,
            batch_size=hparams.wavenet_batch_size,
            receptive_field=WaveNetModel.calculate_receptive_field(
                hparams.filter_width, hparams.dilations, hparams.scalar_input,
                hparams.initial_filter_width),
            gc_enable=gc_enable)
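        # The reader exposes queued batches: the raw waveform (inputs_wav), the
        # mel-spectrogram local condition, and the speaker id used for global conditioning.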
        if gc_enable:
            audio_batch, lc_batch, gc_id_batch = reader.inputs_wav, reader.local_condition, reader.speaker_id
        else:
            # Without global conditioning, audio_batch/lc_batch/gc_id_batch would be
            # undefined below, so fail early instead of crashing later.
            raise ValueError('This script expects more than one speaker directory (gc_enable=True).')

    # Create network.
    net = WaveNetModel(
        batch_size=hparams.wavenet_batch_size,
        dilations=hparams.dilations,
        filter_width=hparams.filter_width,
        residual_channels=hparams.residual_channels,
        dilation_channels=hparams.dilation_channels,
        quantization_channels=hparams.quantization_channels,
        out_channels=hparams.out_channels,
        skip_channels=hparams.skip_channels,
        use_biases=hparams.use_biases,  #  True
        scalar_input=hparams.scalar_input,
        initial_filter_width=hparams.initial_filter_width,
        global_condition_channels=hparams.gc_channels,
        global_condition_cardinality=num_speakers,
        local_condition_channels=hparams.num_mels,
        upsample_factor=hparams.upsample_factor,
        train_mode=True)

    if hparams.l2_regularization_strength == 0:
        hparams.l2_regularization_strength = None
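    # add_loss builds the training loss over the queued batches and add_optimizer
    # attaches the training op (net.optimize) driven by global_step.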

    net.add_loss(input_batch=audio_batch,
                 local_condition=lc_batch,
                 global_condition_batch=gc_id_batch,
                 l2_regularization_strength=hparams.l2_regularization_strength)
    net.add_optimizer(hparams, global_step)

    run_metadata = tf.RunMetadata()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))  # log_device_placement=False --> automatic CPU/GPU placement.
    init = tf.global_variables_initializer()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(
        var_list=tf.global_variables(),
        max_to_keep=hparams.max_checkpoints)  # maximum number of checkpoints to keep

    try:
        start_step = load(saver, sess, restore_from)  # checkpoint load
        if is_overwritten_training or start_step is None:
            # For new or overwritten trainings, reset the global step so that
            # training restarts from step 0.
            zero_step_assign = tf.assign(global_step, 0)
            sess.run(zero_step_assign)

    except:
        print(
            "Something went wrong while restoring checkpoint. We will terminate training to avoid accidentally overwriting the previous model."
        )
        raise

    ###########

    start_step = sess.run(global_step)
    step = last_saved_step = start_step
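    # Main training loop: each iteration runs one optimizer step and logs the loss;
    # every 50 steps (if store_metadata) a full trace is written for the Chrome timeline
    # viewer, and a checkpoint is saved every --checkpoint_every steps.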
    try:
        reader.start_in_session(sess, start_step)
        while not coord.should_stop():

            start_time = time.time()
            if hparams.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                log('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                step, loss_value, _ = sess.run(
                    [global_step, net.loss, net.optimize],
                    options=run_options,
                    run_metadata=run_metadata)

                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                step, loss_value, _ = sess.run(
                    [global_step, net.loss, net.optimize])

            duration = time.time() - start_time
            log('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % config.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

            if step >= hparams.num_steps:
                # An error message is printed, but stopping here is intended.
                raise Exception('End xxx~~~yyy')

    except Exception as e:
        print('Exiting due to exception: {}'.format(e))
        #if step > last_saved_step:
        #    save(saver, sess, logdir, step)

        coord.request_stop(e)
Example 2
def main():
    def _str_to_bool(s):
        """Convert string to bool (in argparse context)."""
        if s.lower() not in ['true', 'false']:
            raise ValueError('Argument needs to be a boolean, got {}'.format(s))
        return {'true': True, 'false': False}[s.lower()]
    
    
    parser = argparse.ArgumentParser(description='WaveNet example network')
    
    DATA_DIRECTORY = '/home/kjm/Tacotron2-Wavenet-Korean-TTS/data/monika,/home/kjm/Tacotron2-Wavenet-Korean-TTS/data/kss'
    #DATA_DIRECTORY =  'D:\\hccho\\Tacotron-Wavenet-Vocoder-hccho\\data\\moon'
    parser.add_argument('--data_dir', type=str, default=DATA_DIRECTORY, help='The directory containing the VCTK corpus.')


    #LOGDIR = None
    LOGDIR = './logdir-wavenet/train/2021-03-10T02-58-23'

    parser.add_argument('--logdir', type=str, default=LOGDIR,help='Directory in which to store the logging information for TensorBoard. If the model already exists, it will restore the state and will continue training. Cannot use with --logdir_root and --restore_from.')
    
    
    parser.add_argument('--logdir_root', type=str, default=None,help='Root directory to place the logging output and generated model. These are stored under the dated subdirectory of --logdir_root. Cannot use with --logdir.')
    parser.add_argument('--restore_from', type=str, default=None,help='Directory in which to restore the model from. This creates the new model under the dated directory in --logdir_root. Cannot use with --logdir.')
    
    
    CHECKPOINT_EVERY = 1000   # checkpoint save interval
    parser.add_argument('--checkpoint_every', type=int, default=CHECKPOINT_EVERY,help='How many steps to save each checkpoint after. Default: ' + str(CHECKPOINT_EVERY) + '.')
    
    
    parser.add_argument('--eval_every', type=int, default=1000,help='Steps between eval on test data')
    
   
    
    config = parser.parse_args()  # options that can be supplied on the command line
    config.data_dir = config.data_dir.split(",")
    
    try:
        directories = validate_directories(config,hparams)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from


    log_path = os.path.join(logdir, 'train.log')
    infolog.init(log_path, logdir)


    global_step = tf.Variable(0, name='global_step', trainable=False)

    if hparams.l2_regularization_strength == 0:
        hparams.l2_regularization_strength = None


    # Create coordinator.
    coord = tf.train.Coordinator()
    num_speakers = len(config.data_dir)
    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = hparams.silence_threshold if hparams.silence_threshold > EPSILON else None
        gc_enable = True  # Before: num_speakers > 1    After: always True
        
        # The AudioReader slices the wav files into inputs: the first receptive_field samples are padded or taken from the previous chunk, and each chunk has size (receptive_field + sample_size).
        reader = DataFeederWavenet(coord,config.data_dir,batch_size=hparams.wavenet_batch_size,gc_enable= gc_enable,test_mode=False)
        
        # Create one DataFeederWavenet for testing; it fetches exactly one example.
        reader_test = DataFeederWavenet(coord,config.data_dir,batch_size=1,gc_enable= gc_enable,test_mode=True,queue_size=1)
        
        

        audio_batch, lc_batch, gc_id_batch = reader.inputs_wav, reader.local_condition, reader.speaker_id


    # Create train network.
    net = create_network(hparams,hparams.wavenet_batch_size,num_speakers,is_training=True)
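    # add_loss builds the training loss over the queued batches (hparams.upsample_type
    # presumably controls how the mel condition is upsampled) and add_optimizer attaches
    # the training op (net.optimize) driven by global_step.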
    net.add_loss(input_batch=audio_batch,local_condition=lc_batch, global_condition_batch=gc_id_batch, l2_regularization_strength=hparams.l2_regularization_strength,upsample_type=hparams.upsample_type)
    net.add_optimizer(hparams,global_step)



    run_metadata = tf.RunMetadata()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))  # log_device_placement=False --> automatic CPU/GPU placement.
    init = tf.global_variables_initializer()
    sess.run(init)
    
    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=hparams.max_checkpoints)  # maximum number of checkpoints to keep
    
    try:
        start_step = load(saver, sess, restore_from)  # checkpoint load
        if is_overwritten_training or start_step is None:
            # For new or overwritten trainings, reset the global step so that
            # training restarts from step 0.
            zero_step_assign = tf.assign(global_step, 0)
            sess.run(zero_step_assign)
            start_step=0
    except:
        print("Something went wrong while restoring checkpoint. We will terminate training to avoid accidentally overwriting the previous model.")
        raise


    ###########

    reader.start_in_session(sess,start_step)
    reader_test.start_in_session(sess,start_step)
    
    ################### Create test network.  <---- Because of the queue creation, the test network is built after the session restore.
    net_test = create_network(hparams,1,num_speakers,is_training=False)
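    # Placeholders for incremental (sample-by-sample) generation: the previously
    # generated samples, the upsampled mel condition for the current step, and the speaker id.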
  
    if hparams.scalar_input:
        samples = tf.placeholder(tf.float32,shape=[net_test.batch_size,None])
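        # Seed generation with one random sample per batch item, uniform in [-1, 1).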
        waveform = 2*np.random.rand(net_test.batch_size).reshape(net_test.batch_size,-1)-1
        
    else:
        samples = tf.placeholder(tf.int32,shape=[net_test.batch_size,None])  # samples: mu_law_encoded values, before one-hot conversion. Shape: (batch_size, length)
        waveform = np.random.randint(hparams.quantization_channels,size=net_test.batch_size).reshape(net_test.batch_size,-1)
    upsampled_local_condition = tf.placeholder(tf.float32,shape=[net_test.batch_size,hparams.num_mels])  
    
        

    speaker_id = tf.placeholder(tf.int32,shape=[net_test.batch_size])  
    next_sample = net_test.predict_proba_incremental(samples,upsampled_local_condition,speaker_id)  # applies the Fast Wavenet Generation Algorithm (arXiv 1611.09482)

        
    sess.run(net_test.queue_initializer)
    



    # There are three placeholders for testing: samples, speaker_id, upsampled_local_condition.
    # Fetch one mel-spectrogram for testing. If it is not fixed here, the thread keeps reading new data; reader_test's job ends here.

    mel_input_test, speaker_id_test = sess.run([reader_test.local_condition,reader_test.speaker_id])


    with tf.variable_scope('wavenet',reuse=tf.AUTO_REUSE):
        upsampled_local_condition_data = net_test.create_upsample(mel_input_test,upsample_type=hparams.upsample_type)
        upsampled_local_condition_data_ = sess.run(upsampled_local_condition_data)  # upsampled_local_condition_data_ is fed to the placeholder upsampled_local_condition via feed_dict.

    ######################################################
    
    
    start_step = sess.run(global_step)
    step = last_saved_step = start_step
    try:        
        
        while not coord.should_stop():
            
            start_time = time.time()
            if hparams.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                log('Storing metadata')
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                step, loss_value, _ = sess.run([global_step, net.loss, net.optimize],options=run_options,run_metadata=run_metadata)

                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                step, loss_value, _ = sess.run([global_step,net.loss, net.optimize])

            duration = time.time() - start_time
            log('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(step, loss_value, duration))
            
            
            if step % config.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step
                
                
            if step % config.eval_every == 0:  # config.eval_every
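                # eval_step presumably runs incremental synthesis with the fixed test
                # mel-spectrogram/speaker id and logs or saves the generated audio.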
                eval_step(sess,logdir,step,waveform,upsampled_local_condition_data_,speaker_id_test,mel_input_test,samples,speaker_id,upsampled_local_condition,next_sample)
            
            if step >= hparams.num_steps:
                # An error message is printed, but stopping here is intended.
                raise Exception('End xxx~~~yyy')
            
    except Exception as e:
        log('Exiting due to exception: %s' % e, slack=True)
        #if step > last_saved_step:
        #    save(saver, sess, logdir, step)        
        traceback.print_exc()
        coord.request_stop(e)