Example #1
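This example trains the WaveNet vocoder end to end: it builds the data feeder and the train/eval graphs, optionally runs a data-dependent initialization forward pass when weight normalization is enabled on a fresh run, and periodically writes TensorBoard summaries, checkpoints, evaluation samples, and speaker-embedding metadata.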
def train(log_dir, args, hparams, input_path):
    save_dir = os.path.join(log_dir, 'wave_pretrained')
    plot_dir = os.path.join(log_dir, 'plots')
    wav_dir = os.path.join(log_dir, 'wavs')
    eval_dir = os.path.join(log_dir, 'eval-dir')
    eval_plot_dir = os.path.join(eval_dir, 'plots')
    eval_wav_dir = os.path.join(eval_dir, 'wavs')
    tensorboard_dir = os.path.join(log_dir, 'wavenet_events')
    meta_folder = os.path.join(log_dir, 'metas')
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(wav_dir, exist_ok=True)
    os.makedirs(eval_dir, exist_ok=True)
    os.makedirs(eval_plot_dir, exist_ok=True)
    os.makedirs(eval_wav_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(meta_folder, exist_ok=True)

    checkpoint_path = os.path.join(save_dir, 'wavenet_model.ckpt')
    input_path = os.path.join(args.base_dir, input_path)

    log('Checkpoint_path: {}'.format(checkpoint_path))
    log('Loading training data from: {}'.format(input_path))
    log('Using model: {}'.format(args.model))
    log(hparams_debug_string())

    # Start by setting a seed for repeatability
    tf.set_random_seed(hparams.wavenet_random_seed)

    # Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = Feeder(coord, input_path, args.base_dir, hparams)

    # Set up model
    global_step = tf.Variable(0, name='global_step', trainable=False)
    model, stats = model_train_mode(args, feeder, hparams, global_step)
    eval_model = model_test_mode(args, feeder, hparams, global_step)

    # Speaker Embeddings metadata
    if hparams.speakers_path is not None:
        speaker_embedding_meta = hparams.speakers_path

    else:
        speaker_embedding_meta = os.path.join(meta_folder, 'SpeakerEmbeddings.tsv')
        if not os.path.isfile(speaker_embedding_meta):
            with open(speaker_embedding_meta, 'w', encoding='utf-8') as f:
                for speaker in hparams.speakers:
                    f.write('{}\n'.format(speaker))

        speaker_embedding_meta = speaker_embedding_meta.replace(log_dir, '..')

    # Bookkeeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    sh_saver = create_shadow_saver(model, global_step)

    log('Wavenet training set to a maximum of {} steps'.format(args.wavenet_train_steps))

    # Session configuration: allocate GPU memory as needed and allow soft device placement
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    run_init = False

    # Train
    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            # saved model restoring
            if args.restore:
                # Restore saved model if the user requested it, default = True
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)

                    if (checkpoint_state and checkpoint_state.model_checkpoint_path):
                        log('Loading checkpoint {}'.format(checkpoint_state.model_checkpoint_path), slack=True)
                        load_averaged_model(sess, sh_saver, checkpoint_state.model_checkpoint_path)
                    else:
                        log('No model to load at {}'.format(save_dir), slack=True)
                        if hparams.wavenet_weight_normalization:
                            run_init = True

                except tf.errors.OutOfRangeError as e:
                    log('Cannot restore checkpoint: {}'.format(e), slack=True)
            else:
                log('Starting new training!', slack=True)
                if hparams.wavenet_weight_normalization:
                    run_init = True

            if run_init:
                log('\nApplying Weight normalization in fresh training. Applying data dependent initialization forward pass..')
                # Create init_model
                init_model, _ = model_train_mode(args, feeder, hparams, global_step, init=True)

            # initializing feeder
            feeder.start_threads(sess)

            if run_init:
                # Run one forward pass for model parameters initialization (make prediction on init_batch)
                _ = sess.run(init_model.tower_y_hat)
                log('Data dependent initialization done. Starting training!')

            # Training loop
            while not coord.should_stop() and step < args.wavenet_train_steps:
                start_time = time.time()
                step, loss, opt = sess.run([global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
                    step, time_window.average, loss, loss_window.average)
                log(message, end='\r', slack=(step % args.checkpoint_interval == 0))

                if np.isnan(loss) or loss > 100:
                    log('Loss exploded to {:.5f} at step {}'.format(loss, step))
                    raise Exception('Loss exploded')

                if step % args.summary_interval == 0:
                    log('\nWriting summary at step {}'.format(step))
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0 or step == args.wavenet_train_steps:
                    save_log(sess, step, model, plot_dir, wav_dir, hparams=hparams, model_name=args.model)
                    save_checkpoint(sess, sh_saver, checkpoint_path, global_step)

                if step % args.eval_interval == 0:
                    log('\nEvaluating at step {}'.format(step))
                    eval_step(sess, step, eval_model, eval_plot_dir, eval_wav_dir, summary_writer=summary_writer,
                              hparams=model._hparams, model_name=args.model)

                if hparams.gin_channels > 0 and (
                        step % args.embedding_interval == 0 or step == args.wavenet_train_steps or step == 1):
                    # Get current checkpoint state
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)

                    # Update Projector
                    log('\nSaving Model Speaker Embeddings visualization..')
                    add_embedding_stats(summary_writer, [model.embedding_table.name], [speaker_embedding_meta],
                                        checkpoint_state.model_checkpoint_path)
                    log('WaveNet Speaker embeddings have been updated on tensorboard!')

            log('Wavenet training complete after {} global steps'.format(args.wavenet_train_steps), slack=True)
            return save_dir

        except Exception as e:
            log('Exiting due to exception: {}'.format(e), slack=True)
            traceback.print_exc()
            coord.request_stop(e)
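For context, here is a minimal sketch of how this train function might be invoked. The attribute names on args mirror what the function reads; the log directory, input path, and default values are illustrative assumptions, not the project's actual CLI defaults, and hparams stands in for the project's hyperparameter object.

# Hypothetical invocation sketch: attribute names follow what train() reads from args;
# the paths and values below are illustrative assumptions, not the project's real defaults.
# Assumes train() and hparams are available from the surrounding module.
from argparse import Namespace

args = Namespace(
    base_dir='',                 # joined onto input_path and passed to the Feeder
    model='WaveNet',             # used for logging and naming saved outputs
    restore=True,                # attempt to resume from an existing checkpoint
    wavenet_train_steps=500000,  # maximum number of global training steps
    checkpoint_interval=5000,    # steps between checkpoints and saved samples
    summary_interval=250,        # steps between TensorBoard summaries
    eval_interval=5000,          # steps between evaluation passes
    embedding_interval=5000,     # steps between speaker-embedding exports
)

# hparams is the project's hyperparameter object (providing wavenet_random_seed,
# wavenet_weight_normalization, gin_channels, speakers/speakers_path, and so on).
save_dir = train('logs-wavenet', args, hparams, 'training_data/train.txt')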
Example #2
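A leaner variant of the same training routine: it keeps the feeder setup, training loop, checkpointing, and periodic evaluation, but drops TensorBoard summaries, the weight-normalization initialization pass, and the speaker-embedding export.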
def train(log_dir, args, hparams, input_path):
    save_dir = os.path.join(log_dir, 'wave_pretrained')
    plot_dir = os.path.join(log_dir, 'plots')
    wav_dir = os.path.join(log_dir, 'wavs')
    eval_dir = os.path.join(log_dir, 'eval-dir')
    eval_plot_dir = os.path.join(eval_dir, 'plots')
    eval_wav_dir = os.path.join(eval_dir, 'wavs')
    tensorboard_dir = os.path.join(log_dir, 'wavenet_events')
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(wav_dir, exist_ok=True)
    os.makedirs(eval_dir, exist_ok=True)
    os.makedirs(eval_plot_dir, exist_ok=True)
    os.makedirs(eval_wav_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)

    # Checkpoint and training-data paths
    checkpoint_path = os.path.join(save_dir, 'wavenet_model.ckpt')
    input_path = os.path.join(args.base_dir, input_path)

    log('Checkpoint_path: {}'.format(checkpoint_path))
    log('Loading training data from: {}'.format(input_path))

    # Start by setting a seed for repeatability
    tf.set_random_seed(hparams.wavenet_random_seed)

    # Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = Feeder(coord, input_path, args.base_dir, hparams)

    # Set up model
    training_step = tf.Variable(0, name='global_step', trainable=False)
    model, stats = model_train_mode(args, feeder, hparams, training_step)
    eval_model = model_test_mode(args, feeder, hparams, training_step)

    # Bookkeeping: step counter, timing/loss windows and the shadow-variable saver
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    sh_saver = create_shadow_saver(model, training_step)
    log('wavenet training set to a maximum of {} steps'.format(args.wavenet_train_steps),
        end='\n==================================================================\n')

    # Allow TensorFlow to grow GPU memory allocation as needed
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        try:
            # Initialize all variables
            sess.run(tf.global_variables_initializer())
            # Restore model from checkpoint
            if args.restore:
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)
                    if (checkpoint_state
                            and checkpoint_state.model_checkpoint_path):
                        log('Loading checkpoint {}'.format(
                            checkpoint_state.model_checkpoint_path),
                            slack=True)
                        load_averaged_model(
                            sess, sh_saver,
                            checkpoint_state.model_checkpoint_path)
                except tf.errors.OutOfRangeError as e:
                    log('Cannot restore checkpoint: {}'.format(e), slack=True)
            else:
                log('Starting new training..', slack=True)
            # Start the feeder threads in this session
            feeder.start_threads(sess)

            # Training loop over global steps
            while not coord.should_stop() and step < args.wavenet_train_steps:
                # Record the start time to measure step duration
                start_time = time.time()
                step, y_hat, loss, opt = sess.run(
                    [training_step, model.y_hat, model.loss, model.optimize])
                # Track step duration and loss in their moving windows
                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                # Log training progress
                message = 'Step = {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
                    step, time_window.average, loss, loss_window.average)
                log(message,
                    end='\r',
                    slack=(step % args.checkpoint_interval == 0))

                # Abort training if the loss explodes
                if loss > 100 or np.isnan(loss):
                    log('Loss exploded to {:.5f} at step {}'.format(
                        loss, step))
                    raise Exception('Loss exploded')
                # Save a checkpoint at every checkpoint interval and at the final step
                if step % args.checkpoint_interval == 0 or step == args.wavenet_train_steps:
                    save_log(sess,
                             step,
                             model,
                             plot_dir,
                             wav_dir,
                             hparams=hparams)
                    save_checkpoint(sess, sh_saver, checkpoint_path,
                                    training_step)
                # Run evaluation at every eval interval
                if step % args.eval_interval == 0:
                    log('Evaluating at step {}'.format(step))
                    eval_step(sess,
                              step,
                              eval_model,
                              eval_plot_dir,
                              eval_wav_dir,
                              summary_writer=None,
                              hparams=model._hparams)

            log('wavenet training complete after {} global steps'.format(
                args.wavenet_train_steps),
                slack=True)
            return save_dir
        except Exception as e:
            log('Exiting due to exception: {}'.format(e), slack=True)
            traceback.print_exc()
            # Ask the coordinator to stop the feeder threads
            coord.request_stop(e)
Example #3
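This variant writes TensorBoard summaries directly to the log directory and restores from an existing checkpoint only when args.restore is set; otherwise it starts training from scratch.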
def train(log_dir, args, hparams, input_path):
    save_dir = os.path.join(log_dir, 'wave_pretrained/')
    eval_dir = os.path.join(log_dir, 'eval-dir')
    audio_dir = os.path.join(log_dir, 'wavs')
    plot_dir = os.path.join(log_dir, 'plots')
    wav_dir = os.path.join(log_dir, 'wavs')
    eval_audio_dir = os.path.join(eval_dir, 'wavs')
    eval_plot_dir = os.path.join(eval_dir, 'plots')
    checkpoint_path = os.path.join(save_dir, 'wavenet_model.ckpt')
    input_path = os.path.join(args.base_dir, input_path)
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(wav_dir, exist_ok=True)
    os.makedirs(audio_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(eval_audio_dir, exist_ok=True)
    os.makedirs(eval_plot_dir, exist_ok=True)

    log('Checkpoint_path: {}'.format(checkpoint_path))
    log('Loading training data from: {}'.format(input_path))
    log('Using model: {}'.format(args.model))
    log(hparams_debug_string())

    #Start by setting a seed for repeatability
    tf.set_random_seed(hparams.wavenet_random_seed)

    #Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = Feeder(coord, input_path, args.base_dir, hparams)

    #Set up model
    global_step = tf.Variable(0, name='global_step', trainable=False)
    model, stats = model_train_mode(args, feeder, hparams, global_step)
    eval_model = model_test_mode(args, feeder, hparams, global_step)

    #Bookkeeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    sh_saver = create_shadow_saver(model, global_step)

    log('Wavenet training set to a maximum of {} steps'.format(
        args.wavenet_train_steps))

    #Allocate GPU memory as needed
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    #Train
    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            #saved model restoring
            checkpoint_state = None  #so the check below is safe when not restoring
            if args.restore:
                #Restore saved model if the user requested it, default = True
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)
                except tf.errors.OutOfRangeError as e:
                    log('Cannot restore checkpoint: {}'.format(e))

            if (checkpoint_state and checkpoint_state.model_checkpoint_path):
                log('Loading checkpoint {}'.format(
                    checkpoint_state.model_checkpoint_path))
                load_averaged_model(sess, sh_saver,
                                    checkpoint_state.model_checkpoint_path)

            else:
                if not args.restore:
                    log('Starting new training!')
                else:
                    log('No model to load at {}'.format(save_dir))

            #initializing feeder
            feeder.start_threads(sess)

            #Training loop
            while not coord.should_stop() and step < args.wavenet_train_steps:
                start_time = time.time()
                step, y_hat, loss, opt = sess.run(
                    [global_step, model.y_hat, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
                    step, time_window.average, loss, loss_window.average)
                log(message, end='\r')

                if loss > 100 or np.isnan(loss):
                    log('Loss exploded to {:.5f} at step {}'.format(
                        loss, step))
                    raise Exception('Loss exploded')

                if step % args.summary_interval == 0:
                    log('\nWriting summary at step {}'.format(step))
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0 or step == args.wavenet_train_steps:
                    save_log(sess,
                             step,
                             model,
                             plot_dir,
                             audio_dir,
                             hparams=hparams)
                    save_checkpoint(sess, sh_saver, checkpoint_path,
                                    global_step)

                if step % args.eval_interval == 0:
                    log('\nEvaluating at step {}'.format(step))
                    eval_step(sess,
                              step,
                              eval_model,
                              eval_plot_dir,
                              eval_audio_dir,
                              summary_writer=summary_writer,
                              hparams=model._hparams)

            log('Wavenet training complete after {} global steps'.format(
                args.wavenet_train_steps))
            return save_dir

        except Exception as e:
            log('Exiting due to Exception: {}'.format(e))
Example #4
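Nearly identical to Example #3, but it checkpoints only on the checkpoint interval (without the final-step condition) and does not return the save directory when training completes.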
def train(log_dir, args, hparams, input_path):
	save_dir = os.path.join(log_dir, 'wave_pretrained/')
	eval_dir = os.path.join(log_dir, 'eval-dir')
	audio_dir = os.path.join(log_dir, 'wavs')
	plot_dir = os.path.join(log_dir, 'plots')
	wav_dir = os.path.join(log_dir, 'wavs')
	eval_audio_dir = os.path.join(eval_dir, 'wavs')
	eval_plot_dir = os.path.join(eval_dir, 'plots')
	checkpoint_path = os.path.join(save_dir, 'wavenet_model.ckpt')
	input_path = os.path.join(args.base_dir, input_path)
	os.makedirs(save_dir, exist_ok=True)
	os.makedirs(wav_dir, exist_ok=True)
	os.makedirs(audio_dir, exist_ok=True)
	os.makedirs(plot_dir, exist_ok=True)
	os.makedirs(eval_audio_dir, exist_ok=True)
	os.makedirs(eval_plot_dir, exist_ok=True)

	log('Checkpoint_path: {}'.format(checkpoint_path))
	log('Loading training data from: {}'.format(input_path))
	log('Using model: {}'.format(args.model))
	log(hparams_debug_string())

	#Start by setting a seed for repeatability
	tf.set_random_seed(hparams.wavenet_random_seed)

	#Set up data feeder
	coord = tf.train.Coordinator()
	with tf.variable_scope('datafeeder') as scope:
		feeder = Feeder(coord, input_path, args.base_dir, hparams)

	#Set up model
	global_step = tf.Variable(0, name='global_step', trainable=False)
	model, stats = model_train_mode(args, feeder, hparams, global_step)
	eval_model = model_test_mode(args, feeder, hparams, global_step)

	#Bookkeeping
	step = 0
	time_window = ValueWindow(100)
	loss_window = ValueWindow(100)
	sh_saver = create_shadow_saver(model, global_step)

	log('Wavenet training set to a maximum of {} steps'.format(args.wavenet_train_steps))

	#Allocate GPU memory as needed
	config = tf.ConfigProto()
	config.gpu_options.allow_growth = True

	#Train
	with tf.Session(config=config) as sess:
		try:
			summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
			sess.run(tf.global_variables_initializer())

			#saved model restoring
			checkpoint_state = None  #so the check below is safe when not restoring
			if args.restore:
				#Restore saved model if the user requested it, default = True
				try:
					checkpoint_state = tf.train.get_checkpoint_state(save_dir)
				except tf.errors.OutOfRangeError as e:
					log('Cannot restore checkpoint: {}'.format(e))

			if (checkpoint_state and checkpoint_state.model_checkpoint_path):
				log('Loading checkpoint {}'.format(checkpoint_state.model_checkpoint_path))
				load_averaged_model(sess, sh_saver, checkpoint_state.model_checkpoint_path)

			else:
				if not args.restore:
					log('Starting new training!')
				else:
					log('No model to load at {}'.format(save_dir))

			#initializing feeder
			feeder.start_threads(sess)

			#Training loop
			while not coord.should_stop() and step < args.wavenet_train_steps:
				start_time = time.time()
				step, y_hat, loss, opt = sess.run([global_step, model.y_hat, model.loss, model.optimize])
				time_window.append(time.time() - start_time)
				loss_window.append(loss)

				message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
					step, time_window.average, loss, loss_window.average)
				log(message, end='\r')

				if loss > 100 or np.isnan(loss):
					log('Loss exploded to {:.5f} at step {}'.format(loss, step))
					raise Exception('Loss exploded')

				if step % args.summary_interval == 0:
					log('\nWriting summary at step {}'.format(step))
					summary_writer.add_summary(sess.run(stats), step)

				if step % args.checkpoint_interval == 0:
					save_log(sess, step, model, plot_dir, audio_dir, hparams=hparams)
					save_checkpoint(sess, sh_saver, checkpoint_path, global_step)

				if step % args.eval_interval == 0:
					log('\nEvaluating at step {}'.format(step))
					eval_step(sess, step, eval_model, eval_plot_dir, eval_audio_dir, summary_writer=summary_writer , hparams=model._hparams)

			log('Wavenet training complete after {} global steps'.format(args.wavenet_train_steps))

		except Exception as e:
			log('Exiting due to Exception: {}'.format(e))