Example #1
def model_test_mode(args, feeder, hparams, global_step):
	with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
		model_name = None
		if args.model in ('Tacotron-2', 'Both'):
			model_name = 'Tacotron'
		model = create_model(model_name or args.model, hparams)
		if hparams.predict_linear:
			model.initialize(feeder.eval_inputs, feeder.eval_input_lengths, feeder.eval_mel_targets, feeder.eval_token_targets, 
				linear_targets=feeder.eval_linear_targets, targets_lengths=feeder.eval_targets_lengths, global_step=global_step,
				is_training=False, is_evaluating=True)
		else:
			model.initialize(feeder.eval_inputs, feeder.eval_input_lengths, feeder.eval_mel_targets, feeder.eval_token_targets, 
				targets_lengths=feeder.eval_targets_lengths, global_step=global_step, is_training=False, is_evaluating=True)
		model.add_loss()
		return model
Example #2
def model_train_mode(args, feeder, hparams, global_step):
    with tf.variable_scope('Duration_model', reuse=tf.AUTO_REUSE) as scope:
        model_name = 'Duration'
        model = create_model(model_name or args.model, hparams)
        model.initialize(feeder.inputs_phoneme,
                         feeder.inputs_type,
                         feeder.inputs_time,
                         feeder.input_lengths,
                         feeder.duration_targets,
                         targets_lengths=feeder.targets_lengths,
                         global_step=global_step,
                         is_training=True,
                         split_infos=feeder.split_infos)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_train_stats(model, hparams)
        return model, stats
Example #3
def model_train_mode(args, feeder, hparams, global_step):
    with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
        model_name = None
        if args.model == 'Tacotron-2':
            model_name = 'Tacotron'
        model = create_model(model_name or args.model, hparams)
        model.initialize(feeder.inputs,
                         feeder.input_lengths,
                         feeder.feature_targets,
                         feeder.token_targets,
                         targets_lengths=feeder.targets_lengths,
                         global_step=global_step,
                         is_training=True)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_train_stats(model, hparams)
        return model, stats
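
A minimal sketch of how model_train_mode and model_test_mode are typically wired together in a train() entry point. The Feeder class, its start_threads method, and argument names such as args.train_steps are assumptions based on this style of Tacotron repository, not a verbatim excerpt.

import os
import tensorflow as tf

def train(log_dir, args, hparams, input_path):
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = Feeder(coord, input_path, hparams)  # assumed Feeder API from the repo

    global_step = tf.Variable(0, name='global_step', trainable=False)
    model, stats = model_train_mode(args, feeder, hparams, global_step)
    eval_model = model_test_mode(args, feeder, hparams, global_step)  # shares weights via AUTO_REUSE; periodic eval omitted here

    saver = tf.train.Saver(max_to_keep=5)
    checkpoint_path = os.path.join(log_dir, 'tacotron_model.ckpt')

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        sess.run(tf.global_variables_initializer())
        feeder.start_threads(sess)  # method name varies between forks
        for _ in range(args.train_steps):
            step, loss, _ = sess.run([global_step, model.loss, model.optimize])
            if step % args.checkpoint_interval == 0:
                summary_writer.add_summary(sess.run(stats), step)
                saver.save(sess, checkpoint_path, global_step=global_step)
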
Example #4
    def load(self, checkpoint_path, hparams, gta=False, model_name='Tacotron'):
        log('Constructing model: %s' % model_name)
        #Force the batch size to be known in order to use attention masking in batch synthesis
        inputs = tf.placeholder(tf.int32, (1, None), name='inputs')
        input_lengths = tf.placeholder(tf.int32, (1), name='input_lengths')

        targets = tf.placeholder(tf.float32, (None, None, hparams.num_mels),
                                 name='mel_targets')
        target_lengths = tf.placeholder(tf.int32, (1), name='target_length')
        gta = True

        #initialize(self, inputs, input_lengths, mel_targets=None, stop_token_targets=None,
        # linear_targets=None, targets_lengths=None, gta=False, global_step=None, is_training=False,
        # is_evaluating=False)

        with tf.variable_scope('Tacotron_model', reuse=tf.AUTO_REUSE) as scope:
            self.model = create_model(model_name, hparams)
            self.model.initialize(inputs=inputs,
                                  input_lengths=input_lengths,
                                  mel_targets=targets,
                                  targets_lengths=target_lengths,
                                  gta=gta,
                                  is_evaluating=True)

            self.mel_outputs = self.model.mel_outputs
            self.alignments = self.model.alignments

        self._hparams = hparams

        self.inputs = inputs
        self.input_lengths = input_lengths
        self.targets = targets
        self.target_lengths = target_lengths

        log('Loading checkpoint: %s' % checkpoint_path)
        #Memory allocation on the GPUs as needed
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        self.session = tf.Session(config=config)
        self.session.run(tf.global_variables_initializer())

        saver = tf.train.Saver()
        saver.restore(self.session, checkpoint_path)
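
A possible synthesize() companion for the load() above, using the placeholders it stores on self. This is only a sketch: text_to_sequence and the cleaners hyperparameter are assumed from the repo's text utilities, and numpy is imported locally to keep the snippet self-contained.

    def synthesize(self, text, mel_target):
        """Sketch: GTA synthesis for a single utterance (text_to_sequence is an assumed repo helper)."""
        import numpy as np
        cleaner_names = [x.strip() for x in self._hparams.cleaners.split(',')]
        seq = np.asarray(text_to_sequence(text, cleaner_names), dtype=np.int32)
        feed_dict = {
            self.inputs: [seq],
            self.input_lengths: np.asarray([len(seq)], dtype=np.int32),
            self.targets: [mel_target],  # [T, num_mels] ground-truth mel, required because gta is forced True
            self.target_lengths: np.asarray([len(mel_target)], dtype=np.int32),
        }
        mels, alignments = self.session.run([self.mel_outputs, self.alignments],
                                            feed_dict=feed_dict)
        return mels[0], alignments[0]
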
Example #5
def model_train_mode(args, feeder, hparams, global_step):
    with tf.variable_scope('Tacotron_model', reuse=tf.AUTO_REUSE) as scope:
        model_name = None
        if args.model == 'Tacotron-2':
            model_name = 'Tacotron'
        model = create_model(model_name or args.model, hparams)
        if hparams.predict_linear:
            model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.token_targets, linear_targets=feeder.linear_targets,
                targets_lengths=feeder.targets_lengths, global_step=global_step, use_vae=hparams.use_vae, is_training=True, 
                split_infos=feeder.split_infos)
        else:
            model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.token_targets,
                targets_lengths=feeder.targets_lengths, global_step=global_step, use_vae=hparams.use_vae, is_training=True, 
                split_infos=feeder.split_infos)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_train_stats(model, hparams)
        return model, stats
Example #6
    def load(self, checkpoint_path, hparams, gta=False, model_name='Tacotron'):
        log('Constructing model: %s' % model_name)
        inputs = tf.placeholder(tf.int32, [None, None], 'inputs')
        input_lengths = tf.placeholder(tf.int32, [None], 'input_lengths')
        targets = tf.placeholder(tf.float32, [None, None, hparams.num_mels],
                                 'mel_targets')
        with tf.variable_scope('model') as scope:
            self.model = create_model(model_name, hparams)
            if gta:
                self.model.initialize(inputs, input_lengths, targets, gta=gta)
            else:
                self.model.initialize(inputs, input_lengths)
            self.alignments = self.model.alignments
            self.mel_outputs = self.model.mel_outputs
            self.stop_token_prediction = self.model.stop_token_prediction
            if hparams.predict_linear and not gta:
                self.linear_outputs = self.model.linear_outputs
                self.linear_spectrograms = tf.placeholder(
                    tf.float32, (None, hparams.num_freq),
                    name='linear_spectrograms')
                self.linear_wav_outputs = audio.inv_spectrogram_tensorflow(
                    self.linear_spectrograms, hparams)

        self.gta = gta
        self._hparams = hparams
        #pad input sequences with the <pad_token> 0 ( _ )
        self._pad = 0
        #explicitly setting the padding to a value that doesn't originally exist in the spectrogram
        #to avoid any possible conflicts, without affecting the output range of the model too much
        if hparams.symmetric_mels:
            self._target_pad = -(hparams.max_abs_value + .1)
        else:
            self._target_pad = -0.1

        log('Loading checkpoint: %s' % checkpoint_path)
        #Memory allocation on the GPU as needed
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        self.session = tf.Session(config=config)
        self.session.run(tf.global_variables_initializer())

        saver = tf.train.Saver()
        saver.restore(self.session, checkpoint_path)
Example #7
    def load(self, checkpoint_path, gta=False, model_name='Tacotron'):
        print('Constructing model: %s' % model_name)
        inputs = tf.placeholder(tf.int32, [1, None], 'inputs')
        input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')

        with tf.variable_scope('model') as scope:
            self.model = create_model(model_name, hparams)
            if hparams.use_vae:
                ref_targets = tf.placeholder(tf.float32,
                                             [1, None, hparams.num_mels],
                                             'ref_targets')
            if gta:
                targets = tf.placeholder(tf.float32,
                                         [1, None, hparams.num_mels],
                                         'mel_targets')

                if hparams.use_vae:
                    self.model.initialize(inputs,
                                          input_lengths,
                                          targets,
                                          gta=gta,
                                          reference_mel=ref_targets)
                else:
                    self.model.initialize(inputs,
                                          input_lengths,
                                          targets,
                                          gta=gta)
            else:
                if hparams.use_vae:
                    self.model.initialize(inputs,
                                          input_lengths,
                                          reference_mel=ref_targets)
                else:
                    self.model.initialize(inputs, input_lengths)
            self.mel_outputs = self.model.mel_outputs
            self.alignment = self.model.alignments[0]

        self.gta = gta
        print('Loading checkpoint: %s' % checkpoint_path)
        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(self.session, checkpoint_path)
Example #8
def model_test_mode(args, feeder, hparams, global_step):
    with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
        model_name = None
        if args.model == 'Tacotron-2':
            model_name = 'Tacotron'
        model = create_model(model_name or args.model, hparams)
        if hparams.predict_linear:
            model.initialize(feeder.eval_inputs, feeder.eval_input_lengths, feeder.eval_mel_targets,
                             feeder.eval_token_targets,
                             linear_targets=feeder.eval_linear_targets, targets_lengths=feeder.eval_targets_lengths,
                             global_step=global_step,
                             is_training=False, is_evaluating=True)
        else:
            model.initialize(feeder.eval_inputs, feeder.eval_input_lengths, feeder.eval_mel_targets,
                             feeder.eval_token_targets,
                             targets_lengths=feeder.eval_targets_lengths, global_step=global_step, is_training=False,
                             is_evaluating=True)
        model.add_loss()
        return model
Example #9
	def load(self, checkpoint_path, hparams, gta=False, model_name='Tacotron'):
		log('Constructing model: %s' % model_name)
		inputs = tf.placeholder(tf.int32, [1, None], 'inputs')
		input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')
		targets = tf.placeholder(tf.float32, [1, None, hparams.num_mels], 'mel_targets')
		with tf.variable_scope('model') as scope:
			self.model = create_model(model_name, hparams)
			if gta:
				self.model.initialize(inputs, input_lengths, targets, gta=gta)
			else:		
				self.model.initialize(inputs, input_lengths)
			self.mel_outputs = self.model.mel_outputs
			self.alignment = self.model.alignments[0]

		self.gta = gta
		self._hparams = hparams

		log('Loading checkpoint: %s' % checkpoint_path)
		self.session = tf.Session()
		self.session.run(tf.global_variables_initializer())
		saver = tf.train.Saver()
		saver.restore(self.session, checkpoint_path)
Example #10
	def load(self, checkpoint_path, hparams, gta=False, vae_code_mode='auto', model_name='Tacotron'):
		log('Constructing model: %s' % model_name)
		#Force the batch size to be known in order to use attention masking in batch synthesis
		inputs = tf.placeholder(tf.int32, (None, None), name='inputs')
		input_lengths = tf.placeholder(tf.int32, (None), name='input_lengths')
		targets = tf.placeholder(tf.float32, (None, None, hparams.num_mels), name='mel_targets')
		lengths = tf.placeholder(tf.float32, (None), name='target_lengths')
		mel_references = tf.placeholder(tf.float32, (None, None, hparams.num_mels), name='mel_references')
		references_lengths = tf.placeholder(tf.float32, (None), name='reference_lengths')
		vae_codes = tf.placeholder(tf.float32, (None, hparams.vae_dim), name='vae_codes')
		split_infos = tf.placeholder(tf.int32, shape=(hparams.tacotron_num_gpus, None), name='split_infos')
		with tf.variable_scope('Tacotron_model', reuse=tf.AUTO_REUSE) as scope:
			self.model = create_model(model_name, hparams)
			if gta:
				if hparams.use_vae:
					#Generate vae_code by Gaussian sampling given the mean and variance, which are generated by the VAE network given the mel_targets. Used in GTA synthesis mode.
					self.model.initialize(inputs, input_lengths, mel_targets=targets, targets_lengths=lengths, gta=gta, use_vae=True, split_infos=split_infos)
				else:
					self.model.initialize(inputs, input_lengths, mel_targets=targets, gta=gta, split_infos=split_infos)
			else:
				if hparams.use_vae:
					if vae_code_mode == 'auto':
						#To generate vae_code by Gaussian sampling given the mean and variance, which are generated by the VAE network given the mel_references. Used in natural synthesis mode without args.modify_vae_dim specified.
						self.model.initialize(inputs, input_lengths, mel_references=mel_references, references_lengths=references_lengths, gta=gta, use_vae=True, split_infos=split_infos)
					elif vae_code_mode == 'feed':
						#Directly feed the specified vae_code into the Tacotron decoder network while the VAE network is not used. Used in eval mode when mel_reference is not given, whether or not args.modify_vae_dim is specified.
						self.model.initialize(inputs, input_lengths, vae_codes=vae_codes, gta=gta, use_vae=True, split_infos=split_infos)
					elif vae_code_mode == 'modify':
						#Directly use the mean generated by the VAE network as vae_code, but with some modification according to the variance. Used in natural synthesis mode with args.modify_vae_dim specified.
						self.model.initialize(inputs, input_lengths, mel_references=mel_references, references_lengths=references_lengths, vae_codes=vae_codes, gta=gta, use_vae=True, split_infos=split_infos)
					elif vae_code_mode == 'inference':
						#To get the vae_code(mean) generated by the VAE given the mel_references. Useful when you wish to check the quality of your VAE latent embedding
						self.model.initialize(mel_references=mel_references, references_lengths=references_lengths, gta=gta, use_vae=True, split_infos=split_infos)
				else:        
					self.model.initialize(inputs, input_lengths, gta=gta, split_infos=split_infos)

			self.mu = self.model.tower_mu
			self.log_var = self.model.tower_log_var
			self.mel_outputs = self.model.tower_mel_outputs
			self.linear_outputs = self.model.tower_linear_outputs if (hparams.predict_linear) else None
			self.alignments = self.model.tower_alignments
			self.stop_token_prediction = self.model.tower_stop_token_prediction

		if hparams.GL_on_GPU:
			self.GLGPU_mel_inputs = tf.placeholder(tf.float32, (None, hparams.num_mels), name='GLGPU_mel_inputs')
			self.GLGPU_lin_inputs = tf.placeholder(tf.float32, (None, hparams.num_freq), name='GLGPU_lin_inputs')

			self.GLGPU_mel_outputs = audio.inv_mel_spectrogram_tensorflow(self.GLGPU_mel_inputs, hparams)
			self.GLGPU_lin_outputs = audio.inv_linear_spectrogram_tensorflow(self.GLGPU_lin_inputs, hparams)

		self.gta = gta
		#force feeding vae codes into the tacotron decoder(for eval mode) or generating the vae codes from the reference mel spectrograms
		self.vae_code_mode = vae_code_mode
		self._hparams = hparams
		#pad input sequences with the <pad_token> 0 ( _ )
		self._pad = 0
		#explicitly setting the padding to a value that doesn't originally exist in the spectrogram
		#to avoid any possible conflicts, without affecting the output range of the model too much
		if hparams.symmetric_mels:
			self._target_pad = -hparams.max_abs_value
		else:
			self._target_pad = 0.

		self.inputs = inputs
		self.input_lengths = input_lengths
		self.targets = targets
		self.lengths = lengths
		self.vae_codes = vae_codes
		self.mel_references = mel_references
		self.references_lengths = references_lengths
		self.split_infos = split_infos

		log('Loading checkpoint: %s' % checkpoint_path)
		#Memory allocation on the GPUs as needed
		config = tf.ConfigProto()
		config.gpu_options.allow_growth = True
		config.allow_soft_placement = True

		self.session = tf.Session(config=config)
		self.session.run(tf.global_variables_initializer())

		saver = tf.train.Saver()
		saver.restore(self.session, checkpoint_path)
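
The four vae_code_mode branches above correspond to how a caller usually decides the mode from what it has available (a reference mel, an explicit code, a request to perturb dimensions). Below is a hedged reconstruction of that decision with illustrative argument names; it is inferred from the comments in the code, not copied from the repository.

def choose_vae_code_mode(reference_mel=None, explicit_code=None, modify_vae_dim=None, code_only=False):
	"""Sketch of the mode selection implied by the comments above; names are illustrative."""
	if code_only:
		return 'inference'  # only extract the code (mean) from the reference mel
	if explicit_code is not None and reference_mel is None:
		return 'feed'       # feed a user-supplied vae_code straight into the decoder
	if reference_mel is not None and modify_vae_dim is not None:
		return 'modify'     # take the code from the reference, then perturb selected dimensions
	if reference_mel is not None:
		return 'auto'       # sample the code from the reference mel
	raise ValueError('need a reference mel, an explicit vae_code, or code_only=True')
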
Example #11
    def load(self, checkpoint_path, hparams, gta=False, model_name='Tacotron'):
        log('Constructing model: %s' % model_name)
        #Force the batch size to be known in order to use attention masking in batch synthesis
        inputs = tf.placeholder(tf.int32, (None, None), name='inputs')
        input_lengths = tf.placeholder(tf.int32, (None), name='input_lengths')
        targets = tf.placeholder(tf.float32, (None, None, hparams.num_mels),
                                 name='mel_targets')
        split_infos = tf.placeholder(tf.int32,
                                     shape=(hparams.tacotron_num_gpus, None),
                                     name='split_infos')
        with tf.variable_scope('Tacotron_model', reuse=tf.AUTO_REUSE) as scope:
            self.model = create_model(model_name, hparams)
            if gta:
                self.model.initialize(inputs,
                                      input_lengths,
                                      targets,
                                      gta=gta,
                                      split_infos=split_infos)
            else:
                self.model.initialize(inputs,
                                      input_lengths,
                                      split_infos=split_infos)

            self.mel_outputs = self.model.tower_mel_outputs
            self.linear_outputs = self.model.tower_linear_outputs if (
                hparams.predict_linear and not gta) else None
            self.alignments = self.model.tower_alignments
            self.stop_token_prediction = self.model.tower_stop_token_prediction
            self.targets = targets

        if hparams.GL_on_GPU:
            self.GLGPU_mel_inputs = tf.placeholder(tf.float32,
                                                   (None, hparams.num_mels),
                                                   name='GLGPU_mel_inputs')
            self.GLGPU_lin_inputs = tf.placeholder(tf.float32,
                                                   (None, hparams.num_freq),
                                                   name='GLGPU_lin_inputs')

            self.GLGPU_mel_outputs = audio.inv_mel_spectrogram_tensorflow(
                self.GLGPU_mel_inputs, hparams)
            self.GLGPU_lin_outputs = audio.inv_linear_spectrogram_tensorflow(
                self.GLGPU_lin_inputs, hparams)

        self.gta = gta
        self._hparams = hparams
        #pad input sequences with the <pad_token> 0 ( _ )
        self._pad = 0
        #explicitly setting the padding to a value that doesn't originally exist in the spectrogram
        #to avoid any possible conflicts, without affecting the output range of the model too much
        if hparams.symmetric_mels:
            self._target_pad = -hparams.max_abs_value
        else:
            self._target_pad = 0.

        self.inputs = inputs
        self.input_lengths = input_lengths
        self.targets = targets
        self.split_infos = split_infos
        OLD_CHECKPOINT_FILE = checkpoint_path
        NEW_CHECKPOINT_FILE = 'logs-Tacotron/taco_pretrained/new_model.ckpt-189500'
        log('Loading checkpoint: %s' % checkpoint_path)
        #UPDATE CHECKPOINT FILE VARS
        vars_to_rename = {
            "Tacotron_model/inference/CBHG_postnet/CBHG_postnet_highwaynet_1/H/biases":
            "Tacotron_model/inference/CBHG_postnet/CBHG_postnet_highwaynet_1/H/bias",
        }
        new_checkpoint_vars = {}
        reader = tf.train.NewCheckpointReader(OLD_CHECKPOINT_FILE)
        for old_name in reader.get_variable_to_shape_map():
            if old_name in vars_to_rename:
                new_name = vars_to_rename[old_name]
            else:
                new_name = old_name
            new_checkpoint_vars[new_name] = tf.Variable(
                reader.get_tensor(old_name))
        #Memory allocation on the GPUs as needed
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        self.session = tf.Session(config=config)
        self.session.run(tf.global_variables_initializer())

        saver = tf.train.Saver(new_checkpoint_vars)
        saver.restore(self.session, checkpoint_path)
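
The variable-renaming block above follows a common TF 1.x recipe that is normally used to write a new checkpoint under the renamed keys (such as the NEW_CHECKPOINT_FILE path defined above), rather than to restore into the live graph. A standalone sketch of that recipe:

import tensorflow as tf

def rename_checkpoint_vars(old_ckpt, new_ckpt, vars_to_rename):
    """Copy a TF 1.x checkpoint, renaming selected variables along the way."""
    with tf.Graph().as_default():
        reader = tf.train.NewCheckpointReader(old_ckpt)
        renamed_vars = []
        for old_name in reader.get_variable_to_shape_map():
            new_name = vars_to_rename.get(old_name, old_name)
            # Recreate each checkpointed tensor under its (possibly renamed) variable name.
            renamed_vars.append(tf.Variable(reader.get_tensor(old_name), name=new_name))
        saver = tf.train.Saver(renamed_vars)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.save(sess, new_ckpt)
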
Example #12
def model_train_mode(args, feeder, hparams, global_step):
    with tf.variable_scope('Tacotron_model', reuse=tf.AUTO_REUSE) as scope:
        model_name = 'Tacotron_emt_attn' if args.emt_attn else 'Tacotron'
        model = create_model(model_name, hparams)
        if hparams.predict_linear:
            raise ValueError('predict linear not implemented')
            model.initialize(args,
                             feeder.inputs,
                             feeder.input_lengths,
                             feeder.mel_targets,
                             feeder.token_targets,
                             linear_targets=feeder.linear_targets,
                             targets_lengths=feeder.targets_lengths,
                             global_step=global_step,
                             is_training=True,
                             split_infos=feeder.split_infos,
                             emt_labels=feeder.emt_labels,
                             spk_labels=feeder.spk_labels,
                             emt_up_labels=feeder.emt_up_labels,
                             spk_up_labels=feeder.spk_up_labels,
                             spk_emb=feeder.spk_emb,
                             ref_mel_emt=feeder.ref_mel_emt,
                             ref_mel_spk=feeder.ref_mel_spk,
                             use_emt_disc=args.emt_disc,
                             use_spk_disc=args.spk_disc,
                             use_intercross=args.intercross,
                             n_emt=len(feeder.total_emt),
                             n_spk=len(feeder.total_spk))
        else:
            emt_up_labels = feeder.emt_up_labels if args.unpaired else None
            spk_up_labels = feeder.spk_up_labels if args.unpaired else None
            ref_mel_up_emt = feeder.ref_mel_up_emt if args.unpaired else None
            ref_mel_up_spk = feeder.ref_mel_up_spk if args.unpaired else None

            ref_mel_emt = feeder.ref_mel_emt if not (
                args.flip_spk_emt) else feeder.ref_mel_spk
            ref_mel_spk = feeder.ref_mel_spk if not (
                args.flip_spk_emt) else feeder.ref_mel_emt
            emt_labels = feeder.emt_labels if not (
                args.flip_spk_emt) else feeder.spk_labels
            spk_labels = feeder.spk_labels if not (
                args.flip_spk_emt) else feeder.emt_labels
            n_emt = len(feeder.total_emt) if not (args.flip_spk_emt) else len(
                feeder.total_spk)
            n_spk = len(feeder.total_spk) if not (args.flip_spk_emt) else len(
                feeder.total_emt)

            model.initialize(args,
                             feeder.inputs,
                             feeder.input_lengths,
                             feeder.mel_targets,
                             feeder.token_targets,
                             targets_lengths=feeder.targets_lengths,
                             global_step=global_step,
                             is_training=True,
                             split_infos=feeder.split_infos,
                             emt_labels=emt_labels,
                             spk_labels=spk_labels,
                             emt_up_labels=emt_up_labels,
                             spk_up_labels=spk_up_labels,
                             ref_mel_emt=ref_mel_emt,
                             ref_mel_spk=ref_mel_spk,
                             ref_mel_up_emt=ref_mel_up_emt,
                             ref_mel_up_spk=ref_mel_up_spk,
                             use_emt_disc=args.emt_disc,
                             use_spk_disc=args.spk_disc,
                             use_intercross=args.intercross,
                             use_unpaired=args.unpaired,
                             n_emt=n_emt,
                             n_spk=n_spk)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_train_stats(model, hparams)
        return model, stats
Example #13
    def load(self, checkpoint_path, hparams, gta=False, model_name='Tacotron'):
        log('Constructing model: %s' % model_name)
        hparams = hparams.parse(
            'tacotron_num_gpus=1,tacotron_batch_size=32,sharpening=1.3,impute_finished=True')
        self.gta = gta
        self._hparams = hparams
        # pad input sequences with the <pad_token> 0 ( _ )
        self._pad = 0
        # explicitly setting the padding to a value that doesn't originally exist in the spectrogram
        # to avoid any possible conflicts, without affecting the output range of the model too much
        if hparams.symmetric_mels:
            self._target_pad = -hparams.max_abs_value
        else:
            self._target_pad = 0.

        # Force the batch size to be known in order to use attention masking in batch synthesis
        inputs = tf.placeholder(tf.int32, (None, None), name='inputs')
        input_lengths = tf.placeholder(tf.int32, (None), name='input_lengths')
        targets = tf.placeholder(
            tf.float32, (None, None, hparams.num_mels), name='mel_targets')
        targets_lengths = tf.placeholder(tf.int32, (None), name='targets_lengths')
        split_infos = tf.placeholder(tf.int32, shape=(
            hparams.tacotron_num_gpus, None), name='split_infos')
        spkid_embeddings = tf.placeholder(
            tf.float32, shape=(None, None), name='spkid_embeddings')
        language_masks = tf.placeholder(
            tf.float32, shape=(None, None), name='language_masks')

        self.inputs = inputs
        self.input_lengths = input_lengths
        self.targets = targets
        self.targets_lengths = targets_lengths
        self.split_infos = split_infos
        self.spkid_embeddings = spkid_embeddings
        self.language_masks = language_masks

        with tf.variable_scope('Tacotron_model', reuse=tf.AUTO_REUSE):
            self.model = create_model(model_name, hparams)
            if gta:
                self.model.initialize(
                    inputs, input_lengths, targets, targets_lengths=targets_lengths, gta=gta, split_infos=split_infos)
            else:
                self.model.initialize(inputs, input_lengths, targets, targets_lengths=targets_lengths, spkid_embeddings=spkid_embeddings if hparams.multispeaker else None,
                                      language_masks=language_masks if hparams.add_lang_space else None, split_infos=split_infos)

            self.mel_outputs = self.model.tower_mel_outputs
            self.alignments = self.model.tower_alignments
            self.stop_token_prediction = self.model.tower_stop_token_prediction
            #self.targets = targets
            if hparams.predict_linear and not gta:
                self.linear_outputs = self.model.tower_linear_outputs
                # the first GPU but with all batches
                self.linear_wav_outputs = audio.inv_spectrogram_tensorflow(
                    self.model.tower_linear_outputs[0], hparams)

        log('Loading checkpoint: %s' % checkpoint_path)
        # Memory allocation on the GPUs as needed
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        self.session = tf.Session(config=config)
        self.session.run(tf.global_variables_initializer())

        saver = tf.train.Saver()
        saver.restore(self.session, checkpoint_path)

        if hparams.singleGPU_no_pyfunc:
            output_node_names = "Tacotron_model/Squeeze,Tacotron_model/inference/Reshape_3"
            output_graph_def = graph_util.convert_variables_to_constants(self.session, tf.get_default_graph().as_graph_def(), output_node_names.split(","))
            model_file = "./saved_model.pb"
            with tf.gfile.GFile(model_file, "wb") as f:
                f.write(output_graph_def.SerializeToString())  # That's it!
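
The frozen GraphDef written above (model_file, with the hard-coded output node names) can later be reloaded for inference without rebuilding the Python model. A minimal TF 1.x sketch; the exact set of placeholders that must be fed depends on how the model was initialized:

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('./saved_model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    mel_outputs = graph.get_tensor_by_name('Tacotron_model/Squeeze:0')
    inputs = graph.get_tensor_by_name('inputs:0')
    input_lengths = graph.get_tensor_by_name('input_lengths:0')

with tf.Session(graph=graph) as sess:
    # mels = sess.run(mel_outputs, feed_dict={inputs: [seq], input_lengths: [len(seq)]})
    pass
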
Example #14
    def load(self,
             checkpoint_path,
             hparams,
             gta=False,
             eal=False,
             model_name='tacotron_pml',
             locked_alignments=None,
             logs_enabled=False,
             checkpoint_eal=None,
             flag_online=False):
        if locked_alignments is not None:
            eal = True
        if logs_enabled:
            log('Constructing model: %s' % model_name)

        inputs = tf.placeholder(tf.int32, [None, None], 'inputs')
        input_lengths = tf.placeholder(tf.int32, [None], 'input_lengths')
        targets = tf.placeholder(tf.float32, [None, None, hparams.num_mels],
                                 'mel_targets')

        with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
            self.model = create_model(model_name, hparams)

            if gta:
                self.model.initialize(inputs,
                                      input_lengths,
                                      mel_targets=targets,
                                      gta=True,
                                      logs_enabled=logs_enabled)
            elif eal:
                self.model.initialize(inputs,
                                      input_lengths,
                                      mel_targets=targets,
                                      eal=True,
                                      locked_alignments=locked_alignments,
                                      logs_enabled=logs_enabled)
            else:
                self.model.initialize(inputs,
                                      input_lengths,
                                      logs_enabled=logs_enabled)

            self.linear_outputs = self.model.linear_outputs

        self.gta, self.eal = gta, eal
        self._hparams = hparams

        self.inputs = inputs
        self.input_lengths = input_lengths
        self.targets = targets

        if logs_enabled:
            log('Loading checkpoint: %s' % checkpoint_path)

        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())

        if checkpoint_eal is None:
            log('Loading all vars from checkpoint: %s' % checkpoint_path)
            saver = tf.train.Saver()
            saver.restore(self.session, checkpoint_path)
        else:
            list_var = [
                var for var in tf.global_variables()
                if 'Location_Sensitive_Attention' in var.name
                and 'Adam' not in var.name
            ]
            list_var += [
                var for var in tf.global_variables()
                if 'memory_layer' in var.name and 'Adam' not in var.name
            ]

            log('Loading all vars from checkpoint: %s' % checkpoint_eal)
            saver_eal = tf.train.Saver()
            saver_eal.restore(self.session, checkpoint_eal)

            log('Overwriting attention mechanism weights from checkpoint: %s' %
                checkpoint_path)
            saver = tf.train.Saver(list_var)
            saver.restore(self.session, checkpoint_path)
Example #15
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import tensorflow as tf 
from tacotron.models import create_model
from tacotron_hparams import hparams
import shutil 

#with tf.device('/cpu:0'):

inputs = tf.placeholder(tf.int32, [1, None], 'inputs')
input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths') 

model_name = 'Tacotron'
with tf.variable_scope('Tacotron_model') as scope:
    model = create_model(model_name, hparams)
    model.initialize(inputs=inputs, input_lengths=input_lengths)


checkpoint_path = tf.train.get_checkpoint_state('./logs-Tacotron-2/taco_pretrained').model_checkpoint_path
#checkpoint_path = './logs-Tacotron-2/taco_pretrained/tacotron_model.ckpt-207000'

sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess, checkpoint_path)

export_path_base = './export' 
if os.path.exists(export_path_base):
    shutil.rmtree(export_path_base)
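
The script above stops right after clearing export_path_base; a hedged sketch of how the export is commonly completed with the TF 1.x SavedModel builder. The signature name and the assumption that the initialized model exposes mel_outputs are illustrative:

export_path = os.path.join(export_path_base, '1')  # versioned SavedModel directory
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

signature = tf.saved_model.signature_def_utils.predict_signature_def(
    inputs={'inputs': inputs, 'input_lengths': input_lengths},
    outputs={'mel_outputs': model.mel_outputs})  # assumes the model exposes mel_outputs

builder.add_meta_graph_and_variables(
    sess,
    [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature,
    })
builder.save()
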
Example #16
    def load(self,
             checkpoint_path,
             hparams,
             gta=False,
             eal=False,
             model_name='tacotron_pml',
             locked_alignments=None,
             logs_enabled=False,
             checkpoint_eal=None,
             flag_online=False):
        if locked_alignments is not None:
            eal = True
        if logs_enabled:
            log('Constructing model: %s' % model_name)

        inputs = tf.placeholder(tf.int32, [None, None], 'inputs')
        input_lengths = tf.placeholder(tf.int32, [None], 'input_lengths')
        targets = tf.placeholder(tf.float32,
                                 [None, None, hparams.pml_dimension],
                                 'pml_targets')

        with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
            self.model = create_model(model_name, hparams)

            if gta:
                self.model.initialize(inputs,
                                      input_lengths,
                                      pml_targets=targets,
                                      gta=True,
                                      logs_enabled=logs_enabled)
            elif eal:
                # self.model.initialize(inputs, input_lengths, eal=True, locked_alignments=locked_alignments, logs_enabled=logs_enabled)
                if flag_online:
                    self.model.initialize(inputs,
                                          input_lengths,
                                          pml_targets=targets,
                                          eal=True,
                                          locked_alignments=None,
                                          logs_enabled=logs_enabled)
                else:
                    self.model.initialize(inputs,
                                          input_lengths,
                                          pml_targets=targets,
                                          eal=True,
                                          locked_alignments=locked_alignments,
                                          logs_enabled=logs_enabled)
            else:
                if flag_online:
                    self.model.initialize(inputs,
                                          input_lengths,
                                          logs_enabled=logs_enabled,
                                          flag_online_eal_eval=True)
                else:
                    self.model.initialize(inputs,
                                          input_lengths,
                                          logs_enabled=logs_enabled)

            self.pml_outputs = self.model.pml_outputs
            if flag_online: self.pml_outputs_eal = self.model.pml_outputs_eal

        self.gta, self.eal, self.flag_online = gta, eal, flag_online
        self._hparams = hparams

        self.inputs = inputs
        self.input_lengths = input_lengths
        self.targets = targets

        if logs_enabled:
            log('Loading checkpoint: %s' % checkpoint_path)

        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())

        if checkpoint_eal is None:
            log('Loading all vars from checkpoint: %s' % checkpoint_path)
            saver = tf.train.Saver()
            saver.restore(self.session, checkpoint_path)
        else:
            #             import pdb
            #             pdb.set_trace()

            list_var = [
                var for var in tf.global_variables()
                if 'Location_Sensitive_Attention' in var.name
                and 'Adam' not in var.name
            ]
            list_var += [
                var for var in tf.global_variables()
                if 'memory_layer' in var.name and 'Adam' not in var.name
            ]
            #             list_var_value = []
            #             for v in list_var+tf.global_variables()[10:13]:
            #                 list_var_value.append(self.session.run([v]))

            log('Loading all vars from checkpoint: %s' % checkpoint_eal)
            saver_eal = tf.train.Saver()
            saver_eal.restore(self.session, checkpoint_eal)

            #             for i,v in enumerate(list_var+tf.global_variables()[10:13]):
            #                 print(v)
            #                 print(np.array_equal(list_var_value[i], self.session.run([v])))
            #                 list_var_value[i] = self.session.run([v])
            #             pdb.set_trace()

            log('Overwriting attention mechanism weights from checkpoint: %s' %
                checkpoint_path)
            saver = tf.train.Saver(list_var)
            saver.restore(self.session, checkpoint_path)
Example #17
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_dir', default='')
    parser.add_argument(
        '--hparams',
        default='',
        help=
        'Hyperparameter overrides as a comma-separated list of name=value pairs'
    )
    parser.add_argument('--model', default='Tacotron-2')
    parser.add_argument('--tacotron_input', default='training_data/train.txt')
    parser.add_argument('--nb_speaker',
                        default=96,
                        help='Number of speaker during training.')
    parser.add_argument('--embedding_dir',
                        default="logs-speaker_embeddings",
                        help='Directory to save the speaker embeddings.')
    parser.add_argument('--debug',
                        default=False,
                        help='Print debugging information')
    parser.add_argument('--verbose',
                        default=True,
                        help='Print progress information')
    args = parser.parse_args()

    debug_flag = args.debug
    verbose_flag = args.verbose

    run_name = args.model
    log_dir = os.path.join(args.base_dir, 'logs-{}'.format(run_name))
    input_path = os.path.join(args.base_dir, args.tacotron_input)
    hparams = hparamspy.parse(args.hparams)
    tensorboard_dir = os.path.join(log_dir, 'tacotron_events')
    save_dir = os.path.join(log_dir, 'taco_pretrained')

    # Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = Feeder(coord, input_path, hparams, split=False)

    # Create model
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('Tacotron_model', reuse=tf.AUTO_REUSE) as scope:
        model = create_model("Tacotron", hparams)
        initialize_args = {
            "inputs": feeder.inputs,
            "input_lengths": feeder.input_lengths,
            "mel_targets": feeder.mel_targets,
            "stop_token_targets": feeder.token_targets,
            "targets_lengths": feeder.targets_lengths,
            "global_step": global_step,
            "is_training": False,
            "split_infos": feeder.split_infos
        }
        if hparams.predict_linear:
            initialize_args["linear_targets"] = feeder.linear_targets
        if hparams.tacotron_reference_waveform:
            initialize_args["mel_references"] = feeder.mel_references
            initialize_args["nb_sample"] = len(feeder._metadata)
        if hparams.tacotron_multi_speaker:
            initialize_args["speaker_id_target"] = feeder.speaker_id_target
            initialize_args["nb_speaker"] = args.nb_speaker
        model.initialize(**initialize_args)

    saver = tf.train.Saver(max_to_keep=5)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:
        summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)

        sess.run(tf.global_variables_initializer())

        # Restore saved model
        checkpoint_state = tf.train.get_checkpoint_state(save_dir)
        saver.restore(sess, checkpoint_state.model_checkpoint_path)

        # Embeddings speaker metadata
        os.makedirs(args.embedding_dir, exist_ok=False)
        speaker_embedding_meta = os.path.join(args.embedding_dir,
                                              'SpeakerEmbeddings.tsv')
        with open(speaker_embedding_meta, 'w', encoding='utf-8') as f:
            f.write("Filename\tSpeaker\n")  # Header

            n = feeder._hparams.tacotron_batch_size
            r = feeder._hparams.outputs_per_step
            speaker_embeddings = []
            examples = []

            if debug_flag:
                print(len(feeder._train_meta))
                print(len(feeder._train_meta[0]))
                print(n * _batches_per_group)

            # Extract speaker label and embedding
            for i in range(n * _batches_per_group):
                # if i<10:
                if i < len(feeder._train_meta):
                    example = feeder._get_next_example()
                    metadata = feeder._train_meta[i]
                    f.write('{}\t{}\n'.format(metadata[1], metadata[-1]))
                    examples.append(example)

                    batch = [example]
                    feed_dict = dict(
                        zip(feeder._placeholders,
                            feeder._prepare_batch(batch, r)))
                    sess.run(feeder._enqueue_op, feed_dict=feed_dict)
                    speaker_embedding = sess.run([model.embedding_speaker])
                    speaker_embeddings.append(speaker_embedding)

                    if verbose_flag:
                        print("\r\r\r\r\r\r\r\r{}/{}".format(
                            i, len(feeder._train_meta)),
                              end=" ")
            if verbose_flag:
                print(" ")

        # Reshape the embeddings data
        speaker_embeddings = np.array(speaker_embeddings)
        if debug_flag:
            print(speaker_embeddings.shape)
        speaker_embeddings = speaker_embeddings.reshape((-1, 64))
        if debug_flag:
            print(speaker_embeddings.shape)

        # Save embeddings data for Tensorboard
        spk_emb = tf.Variable(speaker_embeddings, name='speaker_embeddings')
        with tf.Session(config=config) as sess:
            saver = tf.train.Saver([spk_emb])

            sess.run(spk_emb.initializer)
            saver.save(
                sess,
                os.path.join(args.embedding_dir, 'speaker_embeddings.ckpt'))

            config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
            # One can add multiple embeddings.
            embedding = config.embeddings.add()
            embedding.tensor_name = spk_emb.name
            # Link this tensor to its metadata file (e.g. labels).
            embedding.metadata_path = 'SpeakerEmbeddings.tsv'
            # Saves a config file that TensorBoard will read during startup.
            tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
                tf.summary.FileWriter("logs-speaker_embeddings"), config)
Example #18
    def load(self,
             checkpoint,
             hparams,
             gta=False,
             model_name='tacotron_pml',
             locked_alignments=None,
             cut_lengths=True):
        """
        :param checkpoint:
        :param hparams:
        :param gta:
        :param model_name:
        :param locked_alignments:
        :param cut_lengths: boolean flag that controls whether to cut output sequence lengths from the target data
        :return:
        """
        print('Constructing model: %s' % model_name)
        inputs = tf.placeholder(tf.int32, [None, None], 'inputs')
        input_lengths = tf.placeholder(tf.int32, [None], 'input_lengths')
        if model_name in ['tacotron_bk2orig']:
            targets = tf.placeholder(tf.float32,
                                     [None, None, hparams.num_mels],
                                     'mel_targets')
        else:
            targets = tf.placeholder(tf.float32,
                                     [None, None, hparams.pml_dimension],
                                     'pml_targets')

        with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
            self.model = create_model(model_name, hparams)

            if gta:
                if model_name in ['tacotron_bk2orig']:
                    self.model.initialize(inputs,
                                          input_lengths,
                                          mel_targets=targets,
                                          gta=gta,
                                          locked_alignments=locked_alignments)
                else:
                    self.model.initialize(inputs,
                                          input_lengths,
                                          pml_targets=targets,
                                          gta=gta,
                                          locked_alignments=locked_alignments)
            else:
                self.model.initialize(inputs,
                                      input_lengths,
                                      locked_alignments=locked_alignments)

            self.alignments = self.model.alignments

        self.gta = gta
        self._hparams = hparams
        self.targets = targets
        self.cut_lengths = cut_lengths

        print('Loading checkpoint: %s' % checkpoint)
        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(self.session, checkpoint)
Example #19
def train(log_dir, args, input):
    commit = get_git_commit() if args.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, input)
    log('Checkpoint path: %s' % checkpoint_path)
    log('Loading training data from: %s' % input_path)
    log('Using model: %s' % args.variant)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        if args.eal_dir:
            from tacotron.datafeeder import DataFeeder_EAL
            feeder = DataFeeder_EAL(coord, input_path, hparams, args.eal_dir)
        else:
            from tacotron.datafeeder import DataFeeder
            feeder = DataFeeder(coord, input_path, hparams)

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.variant, hparams)
        if args.eal_dir:
            model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets,
                             feeder.linear_targets, feeder.pml_targets, is_training=True, 
                             eal=True, locked_alignments=feeder.locked_alignments, 
                             flag_trainAlign=args.eal_trainAlign, flag_trainJoint=args.eal_trainJoint, alignScale=args.eal_alignScale)
        else:
            model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets,
                             feeder.linear_targets, feeder.pml_targets, is_training=True, 
                             gta=True)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model, eal_dir=args.eal_dir)

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

    # Set up fixed alignment synthesizer
    alignment_synth = AlignmentSynthesizer()

    # Set up text for synthesis
    fixed_sentence = 'Scientists at the CERN laboratory say they have discovered a new particle.'

    # Set up denormalisation parameters for synthesis
    mean_path = os.path.abspath(os.path.join(args.base_dir, input, '..', 'pml_data/mean.dat'))
    std_path = os.path.abspath(os.path.join(args.base_dir, input, '..', 'pml_data/std.dat'))
    log('Loading normalisation mean from: {}'.format(mean_path))
    log('Loading normalisation standard deviation from: {}'.format(std_path))
    mean_norm = None
    std_norm = None

    if os.path.isfile(mean_path) and os.path.isfile(std_path):
        mean_norm = np.fromfile(mean_path, 'float32')
        std_norm = np.fromfile(std_path, 'float32')

    # Train!
#     import pdb
#     flag_pdb = False
#     pdb.set_trace()
#     args.checkpoint_interval = 2
#     args.num_steps = 5
    
    with tf.Session() as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())
            
#             pdb.set_trace()
            
            if args.restore_step:
                # Restore from a checkpoint if the user requested it.
                restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
            elif args.eal_dir and args.eal_ckpt:
                if args.eal_trainAlign or args.eal_trainJoint:
                    list_var = tf.trainable_variables() + [v for v in tf.global_variables() if 'moving' in v.name]
                    saver_eal = tf.train.Saver(list_var)
                    saver_eal.restore(sess, args.eal_ckpt)
                    log('Loaded weights and batchNorm cache of checkpoint: %s at commit: %s' % (args.eal_ckpt, commit), slack=True)
                elif args.eal_ft:
                    saver.restore(sess, args.eal_ckpt)
                    log('Refining the model from checkpoint: %s at commit: %s' % (args.eal_ckpt, commit), slack=True)
                else:
                    list_var = [var for var in tf.global_variables() if 'optimizer' not in var.name]
                    saver_eal = tf.train.Saver(list_var)
                    saver_eal.restore(sess, args.eal_ckpt)
                    log('Initializing the weights from checkpoint: %s at commit: %s' % (args.eal_ckpt, commit), slack=True)
#                 args.num_steps *= 2
#                 sess.run(global_step.assign(0))
            else:
                log('Starting new training run at commit: %s' % commit, slack=True)

            feeder.start_in_session(sess)
            step = 0  # initialise step variable so can use in while condition
            
            while not coord.should_stop() and step <= args.num_steps:
                
#                 pdb.set_trace()
                                
                start_time = time.time()
                if args.eal_trainAlign:
                    step, loss, loss_align, opt = sess.run([global_step, model.loss, model.loss_align, model.optimize])
#                     try:
#                         step, loss, loss_align, opt, tmp_a, tmp_ar = sess.run([global_step, model.loss, model.loss_align, model.optimize, 
#                                                                                model.alignments, model.alignments_ref])
#                     except:
#                         print("Oops!",sys.exc_info()[0],"occured.")
#                         flag_pdb = True
#                     if flag_pdb or np.isnan(loss_align):
#                         pdb.set_trace()
#                         flag_pdb = False
                    time_window.append(time.time() - start_time)
                    loss_window.append(loss_align)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, loss_align=%.05f, avg_loss_align=%.05f]' % (
                        step, time_window.average, loss, loss_align, loss_window.average)
                elif args.eal_trainJoint:
                    step, loss, loss_align, loss_joint, opt = sess.run([global_step, model.loss, model.loss_align, 
                                                                        model.loss_joint, model.optimize])
                    time_window.append(time.time() - start_time)
                    loss_window.append(loss_joint)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, loss_align=%.05f, avg_loss_joint=%.05f]' % (
                        step, time_window.average, loss, loss_align, loss_window.average)
                else:
                    step, loss, opt = sess.run([global_step, model.loss, model.optimize])
                    time_window.append(time.time() - start_time)
                    loss_window.append(loss)
                    message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                        step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % args.checkpoint_interval == 0))
                
                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
                    raise Exception('Loss Exploded')

                if step % args.summary_interval == 0:
                    log('Writing summary at step: %d' % step)
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    log('Saving audio and alignment...')
                    summary_elements = []

                    # if the model has linear spectrogram features, use them to synthesize audio
                    if hasattr(model, 'linear_targets'):
                        input_seq, alignment, target_spectrogram, spectrogram = sess.run([
                            model.inputs[0], model.alignments[0], model.linear_targets[0], model.linear_outputs[0]])

                        output_waveform = audio.inv_spectrogram(spectrogram.T)
                        target_waveform = audio.inv_spectrogram(target_spectrogram.T)
                        audio.save_wav(output_waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
                        audio.save_wav(target_waveform, os.path.join(log_dir, 'step-%d-target-audio.wav' % step))
                    # otherwise, synthesize audio from PML vocoder features
                    elif hasattr(model, 'pml_targets'):
                        input_seq, alignment, target_pml_features, pml_features = sess.run([
                            model.inputs[0], model.alignments[0], model.pml_targets[0], model.pml_outputs[0]])

                        cfg = Configuration(hparams.sample_rate, hparams.pml_dimension)
                        synth = PMLSynthesizer(cfg)
                        output_waveform = synth.pml_to_wav(pml_features, mean_norm=mean_norm, std_norm=std_norm,
                                                           spec_type=hparams.spec_type)
                        target_waveform = synth.pml_to_wav(target_pml_features, mean_norm=mean_norm, std_norm=std_norm,
                                                           spec_type=hparams.spec_type)

                        sp.wavwrite(os.path.join(log_dir, 'step-%d-target-audio.wav' % step), target_waveform,
                                    hparams.sample_rate, norm_max_ifneeded=True)
                        sp.wavwrite(os.path.join(log_dir, 'step-%d-audio.wav' % step), output_waveform,
                                    hparams.sample_rate, norm_max_ifneeded=True)

                    # we need to adjust the output and target waveforms so the values lie in the interval [-1.0, 1.0]
                    output_waveform /= 1.05 * np.max(np.abs(output_waveform))
                    target_waveform /= 1.05 * np.max(np.abs(target_waveform))

                    summary_elements.append(
                        tf.summary.audio('ideal-%d' % step, np.expand_dims(target_waveform, 0), hparams.sample_rate),
                    )

                    summary_elements.append(
                        tf.summary.audio('sample-%d' % step, np.expand_dims(output_waveform, 0), hparams.sample_rate),
                    )

                    # get the alignment for the top sentence in the batch
                    random_attention_plot = plot.plot_alignment(alignment, os.path.join(log_dir,
                                                                                        'step-%d-random-align.png' % step),
                                                                info='%s, %s, %s, step=%d, loss=%.5f' % (
                                                                args.variant, commit, time_string(), step, loss))

                    summary_elements.append(
                        tf.summary.image('attention-%d' % step, random_attention_plot),
                    )

                    # also process the alignment for a fixed sentence for comparison
                    alignment_synth.load('%s-%d' % (checkpoint_path, step), hparams, model_name=args.variant)
                    fixed_alignment = alignment_synth.synthesize(fixed_sentence)
                    fixed_attention_plot = plot.plot_alignment(fixed_alignment,
                                                               os.path.join(log_dir, 'step-%d-fixed-align.png' % step),
                                                               info='%s, %s, %s, step=%d, loss=%.5f' % (
                                                               args.variant, commit, time_string(), step, loss))

                    summary_elements.append(
                        tf.summary.image('fixed-attention-%d' % step, fixed_attention_plot),
                    )

                    # save the audio and alignment to tensorboard (audio sample rate is hyperparameter)
                    merged = sess.run(tf.summary.merge(summary_elements))

                    summary_writer.add_summary(merged, step)

                    log('Input: %s' % sequence_to_text(input_seq))

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
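
The running averages reported in the training loop above (time_window.average, loss_window.average) rely on a ValueWindow helper that is not shown in this example. A minimal sketch of such a fixed-size running-average window, assuming only the interface used above:

class ValueWindow:
    """Keeps the most recent window_size values and exposes their running average (sketch)."""

    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def append(self, x):
        # keep at most window_size values by dropping the oldest ones
        self._values = self._values[-(self._window_size - 1):] + [x]

    @property
    def average(self):
        return sum(self._values) / max(1, len(self._values))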
Example No. 20
0
    def load(self, checkpoint_path, hparams, gta=False, model_name='Tacotron'):
        log('Constructing model: %s' % model_name)
        #Force the batch size to be known in order to use attention masking in batch synthesis
        if not hparams.tacotron_phoneme_transcription:
            inputs = tf.placeholder(tf.int32, (None, None), name='inputs')
        else:
            inputs = tf.placeholder(tf.int32, (None, None, None),
                                    name='inputs')
        input_lengths = tf.placeholder(tf.int32, (None), name='input_lengths')
        targets = tf.placeholder(tf.float32, (None, None, hparams.num_mels),
                                 name='mel_targets')
        split_infos = tf.placeholder(tf.int32,
                                     shape=(hparams.tacotron_num_gpus, None),
                                     name='split_infos')
        if hparams.tacotron_reference_waveform:
            mel_references = tf.placeholder(tf.float32,
                                            (None, None, hparams.num_mels),
                                            name='mel_references')

        with tf.variable_scope('Tacotron_model') as scope:
            self.model = create_model(model_name, hparams)

            initialize_args = {
                "inputs": inputs,
                "input_lengths": input_lengths,
                "split_infos": split_infos,
                "synthesize": True
            }

            if gta:
                initialize_args["gta"] = gta
                initialize_args["mel_targets"] = targets
            if hparams.tacotron_reference_waveform:
                initialize_args["mel_references"] = mel_references
                # initialize_args["nb_sample"] = len(feeder._metadata)
            # if hparams.tacotron_multi_speaker:
            # 	initialize_args["speaker_id_target"] = feeder.speaker_id_target
            # 	initialize_args["nb_speaker"] = feeder._nb_speaker
            self.model.initialize(**initialize_args)

            self.mel_outputs = self.model.tower_mel_outputs
            self.linear_outputs = self.model.tower_linear_outputs if (
                hparams.predict_linear and not gta) else None
            self.alignments = self.model.tower_alignments
            self.stop_token_prediction = self.model.tower_stop_token_prediction
            self.targets = targets
            if hparams.tacotron_reference_waveform:
                self.mel_references = mel_references

        self.gta = gta
        self._hparams = hparams
        #pad input sequences with the <pad_token> 0 ( _ )
        self._pad = 0
        #explicitly setting the padding to a value that doesn't originally exist in the spectrogram
        #to avoid any possible conflicts, without affecting the output range of the model too much
        if hparams.symmetric_mels:
            self._target_pad = -hparams.max_abs_value
        else:
            self._target_pad = 0.

        self.inputs = inputs
        self.input_lengths = input_lengths
        self.targets = targets
        self.split_infos = split_infos

        log('Loading checkpoint: %s' % checkpoint_path)
        #Memory allocation on the GPUs as needed
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        self.session = tf.Session(config=config)
        self.session.run(tf.global_variables_initializer())

        saver = tf.train.Saver()
        saver.restore(self.session, checkpoint_path)
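
The _pad and _target_pad values stored above are only consumed by the synthesis step, which is not shown here. A minimal sketch of the input-padding helper they imply, assuming inputs arrive as a list of integer id sequences:

import numpy as np

def prepare_inputs(sequences, pad_value=0):
    # right-pad every id sequence to the longest one in the batch with the pad token (sketch)
    max_len = max(len(seq) for seq in sequences)
    padded = np.stack([
        np.pad(np.asarray(seq, dtype=np.int32), (0, max_len - len(seq)),
               mode='constant', constant_values=pad_value)
        for seq in sequences
    ])
    return padded, max_len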
Example No. 21
0
import os

import tensorflow as tf
from tensorflow.python.framework import graph_util

# project-local imports; module paths assume the Tacotron-2 style layout used by these examples
from hparams import hparams
from tacotron.models import create_model

checkpoint_dir = "data/LJSpeech-1.1/logs-Tacotron/taco_pretrained/"
checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
output_file = "tf.pb_gpu_2"

sess = tf.InteractiveSession()

with tf.variable_scope('Tacotron_model', reuse=tf.AUTO_REUSE) as scope:
    with tf.device("/cpu:0"):
        inputs = tf.placeholder(tf.int32, [1, None], name="text")
        inputs_len = tf.placeholder(tf.int32, [1], name="text_len")
        split_infos = tf.placeholder(tf.int32,
                                     shape=(hparams.tacotron_num_gpus, None),
                                     name='split_infos')
        model = create_model("Tacotron", hparams)
        model.initialize(inputs,
                         inputs_len,
                         is_training=False,
                         is_evaluating=False,
                         split_infos=split_infos)
        print("#######")
        print(model.tower_mel_outputs)
        # output = model.tower_mel_outputs[0][0]
        # tf.identity(output, "mel_target", )

        saver = tf.train.Saver(tf.global_variables())
        saver.restore(sess, checkpoint_path)

frozen_graph_def = graph_util.convert_variables_to_constants(
    sess, sess.graph_def, ['Tacotron_model/mel_outputs'])
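
The snippet builds frozen_graph_def but never writes it to output_file. A minimal sketch of serializing the frozen graph and reloading it for inference, assuming the output node name passed to convert_variables_to_constants actually exists in the graph:

# serialize the frozen GraphDef to disk
with tf.gfile.GFile(output_file, 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())

# reload it later in a fresh graph for inference
graph_def = tf.GraphDef()
with tf.gfile.GFile(output_file, 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default():
    tf.import_graph_def(graph_def, name='')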
Example No. 22
0
def train(log_dir, args):
    save_dir = os.path.join(log_dir, 'pretrained/')
    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    plot_dir = os.path.join(log_dir, 'plots')
    wav_dir = os.path.join(log_dir, 'wavs')
    mel_dir = os.path.join(log_dir, 'mel-spectrograms')
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(wav_dir, exist_ok=True)
    os.makedirs(mel_dir, exist_ok=True)

    if hparams.predict_linear:
        linear_dir = os.path.join(log_dir, 'linear-spectrograms')
        os.makedirs(linear_dir, exist_ok=True)

    log('Checkpoint path: {}'.format(checkpoint_path))
    log('Loading training data from: {}'.format(input_path))
    log('Using model: {}'.format(args.model))
    log(hparams_debug_string())

    #Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = Feeder(coord, input_path, hparams)

    #Set up model:
    step_count = 0
    try:
        #simple text file to keep count of global step
        with open(os.path.join(log_dir, 'step_counter.txt'), 'r') as file:
            step_count = int(file.read())
    except (FileNotFoundError, ValueError):
        print(
            'no step_counter file found, assuming there is no saved checkpoint'
        )

    global_step = tf.Variable(step_count, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        if hparams.predict_linear:
            model.initialize(feeder.inputs, feeder.input_lengths,
                             feeder.mel_targets, feeder.token_targets,
                             feeder.linear_targets)
        else:
            model.initialize(feeder.inputs, feeder.input_lengths,
                             feeder.mel_targets, feeder.token_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    #Book keeping
    step = 0
    save_step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5)

    #Memory allocation on the GPU as needed
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    #Train
    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            #saved model restoring
            checkpoint_state = None
            if args.restore:
                #Restore saved model if the user requested it, Default = True.
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)
                except tf.errors.OutOfRangeError as e:
                    log('Cannot restore checkpoint: {}'.format(e))

            if (checkpoint_state and checkpoint_state.model_checkpoint_path):
                log('Loading checkpoint {}'.format(
                    checkpoint_state.model_checkpoint_path))
                saver.restore(sess, checkpoint_state.model_checkpoint_path)

            else:
                if not args.restore:
                    log('Starting new training!')
                else:
                    log('No model to load at {}'.format(save_dir))

            #initiating feeder
            feeder.start_in_session(sess)

            #Training loop
            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
                    step, time_window.average, loss, loss_window.average)
                log(message, end='\r')

                if loss > 100 or np.isnan(loss):
                    log('Loss exploded to {:.5f} at step {}'.format(
                        loss, step))
                    raise Exception('Loss exploded')

                if step % args.summary_interval == 0:
                    log('\nWriting summary at step: {}'.format(step))
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    with open(os.path.join(log_dir, 'step_counter.txt'),
                              'w') as file:
                        file.write(str(step))
                    log('Saving checkpoint to: {}-{}'.format(
                        checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    save_step = step

                    log('Saving alignment, Mel-Spectrograms and griffin-lim inverted waveform..'
                        )
                    if hparams.predict_linear:
                        input_seq, mel_prediction, linear_prediction, alignment, target = sess.run(
                            [
                                model.inputs[0],
                                model.mel_outputs[0],
                                model.linear_outputs[0],
                                model.alignments[0],
                                model.mel_targets[0],
                            ])

                        #save predicted linear spectrogram to disk (debug)
                        linear_filename = 'linear-prediction-step-{}.npy'.format(
                            step)
                        np.save(os.path.join(linear_dir, linear_filename),
                                linear_prediction.T,
                                allow_pickle=False)

                        #save griffin lim inverted wav for debug (linear -> wav)
                        wav = audio.inv_linear_spectrogram(linear_prediction.T)
                        audio.save_wav(
                            wav,
                            os.path.join(
                                wav_dir,
                                'step-{}-waveform-linear.wav'.format(step)))

                    else:
                        input_seq, mel_prediction, alignment, target = sess.run(
                            [
                                model.inputs[0],
                                model.mel_outputs[0],
                                model.alignments[0],
                                model.mel_targets[0],
                            ])

                    #save predicted mel spectrogram to disk (debug)
                    mel_filename = 'mel-prediction-step-{}.npy'.format(step)
                    np.save(os.path.join(mel_dir, mel_filename),
                            mel_prediction.T,
                            allow_pickle=False)

                    #save griffin lim inverted wav for debug (mel -> wav)
                    wav = audio.inv_mel_spectrogram(mel_prediction.T)
                    audio.save_wav(
                        wav,
                        os.path.join(wav_dir,
                                     'step-{}-waveform-mel.wav'.format(step)))

                    #save alignment plot to disk (control purposes)
                    plot.plot_alignment(
                        alignment,
                        os.path.join(plot_dir,
                                     'step-{}-align.png'.format(step)),
                        info='{}, {}, step={}, loss={:.5f}'.format(
                            args.model, time_string(), step, loss))
                    #save real mel-spectrogram plot to disk (control purposes)
                    plot.plot_spectrogram(
                        target,
                        os.path.join(
                            plot_dir,
                            'step-{}-real-mel-spectrogram.png'.format(step)),
                        info='{}, {}, step={}, Real'.format(
                            args.model, time_string(), step))
                    #save predicted mel-spectrogram plot to disk (control purposes)
                    plot.plot_spectrogram(
                        mel_prediction,
                        os.path.join(
                            plot_dir,
                            'step-{}-pred-mel-spectrogram.png'.format(step)),
                        info='{}, {}, step={}, loss={:.5f}'.format(
                            args.model, time_string(), step, loss))
                    log('Input at step {}: {}'.format(
                        step, sequence_to_text(input_seq)))

        except Exception as e:
            log('Exiting due to exception: {}'.format(e), slack=True)
            traceback.print_exc()
            coord.request_stop(e)
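
This train() only receives args and never defines it. A minimal sketch of the command-line parser it appears to assume; the flag names are taken from the attributes used above, and the defaults are guesses:

import argparse

def str2bool(value):
    return str(value).lower() in ('true', '1', 'yes')

parser = argparse.ArgumentParser()
parser.add_argument('--base_dir', default='')
parser.add_argument('--input', default='training_data/train.txt')
parser.add_argument('--model', default='Tacotron')
parser.add_argument('--restore', type=str2bool, default=True)
parser.add_argument('--summary_interval', type=int, default=250)
parser.add_argument('--checkpoint_interval', type=int, default=1000)
args = parser.parse_args()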
Example No. 23
0
def train(log_dir, config):

    data_dirs = [os.path.join(data_path, "data") \
            for data_path in config.data_paths]
    num_speakers = len(data_dirs)
    config.num_test = config.num_test_per_speaker * num_speakers

    if num_speakers > 1 and hparams.model_type not in ["deepvoice", "simple"]:
        raise Exception("[!] Unknown model_type for multi-speaker: {}".format(
            hparams.model_type))

    commit = get_git_commit() if config.git else 'None'
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')

    log(' [*] git rev-parse HEAD:\n%s' % get_git_revision_hash())
    log('=' * 50)
    #log(' [*] git diff:\n%s' % get_git_diff())
    log('=' * 50)
    log(' [*] Checkpoint path: %s' % checkpoint_path)
    log(' [*] Loading training data from: %s' % data_dirs)
    log(' [*] Using model: %s' % config.model_dir)
    log(hparams_debug_string())

    # Set up DataFeeder:
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        train_feeder = DataFeeder(coord,
                                  data_dirs,
                                  hparams,
                                  config,
                                  32,
                                  data_type='train',
                                  batch_size=hparams.batch_size)
        test_feeder = DataFeeder(coord,
                                 data_dirs,
                                 hparams,
                                 config,
                                 8,
                                 data_type='test',
                                 batch_size=config.num_test)

    # Set up model:
    is_randomly_initialized = config.initialize_path is None
    global_step = tf.Variable(0, name='global_step', trainable=False)

    with tf.variable_scope('model') as scope:
        model = create_model(hparams)
        model.initialize(train_feeder.inputs,
                         train_feeder.input_lengths,
                         num_speakers,
                         train_feeder.speaker_id,
                         train_feeder.mel_targets,
                         train_feeder.linear_targets,
                         train_feeder.loss_coeff,
                         is_randomly_initialized=is_randomly_initialized)

        model.add_loss()
        model.add_optimizer(global_step)
        train_stats = add_stats(model, scope_name='stats')  # legacy

    with tf.variable_scope('model', reuse=True) as scope:
        test_model = create_model(hparams)
        test_model.initialize(test_feeder.inputs,
                              test_feeder.input_lengths,
                              num_speakers,
                              test_feeder.speaker_id,
                              test_feeder.mel_targets,
                              test_feeder.linear_targets,
                              test_feeder.loss_coeff,
                              rnn_decoder_test_mode=True,
                              is_randomly_initialized=is_randomly_initialized)
        test_model.add_loss()

    test_stats = add_stats(test_model, model, scope_name='test')
    test_stats = tf.summary.merge([test_stats, train_stats])

    # Bookkeeping:
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2)

    sess_config = tf.ConfigProto(log_device_placement=False,
                                 allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True

    # Train!
    with tf.Session(config=sess_config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            if config.load_path:
                # Restore from a checkpoint if the user requested it.
                restore_path = get_most_recent_checkpoint(config.model_dir)
                saver.restore(sess, restore_path)
                log('Resuming from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)
            elif config.initialize_path:
                restore_path = get_most_recent_checkpoint(
                    config.initialize_path)
                saver.restore(sess, restore_path)
                log('Initialized from checkpoint: %s at commit: %s' %
                    (restore_path, commit),
                    slack=True)

                zero_step_assign = tf.assign(global_step, 0)
                sess.run(zero_step_assign)

                start_step = sess.run(global_step)
                log('=' * 50)
                log(' [*] Global step is reset to {}'. \
                        format(start_step))
                log('=' * 50)
            else:
                log('Starting new training run at commit: %s' % commit,
                    slack=True)

            start_step = sess.run(global_step)

            train_feeder.start_in_session(sess, start_step)
            test_feeder.start_in_session(sess, start_step)

            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss_without_coeff, model.optimize],
                    feed_dict=model.get_dummy_feed_dict())

                time_window.append(time.time() - start_time)
                loss_window.append(loss)

                message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
                    step, time_window.average, loss, loss_window.average)
                log(message, slack=(step % config.checkpoint_interval == 0))

                if loss > 100 or math.isnan(loss):
                    log('Loss exploded to %.05f at step %d!' % (loss, step),
                        slack=True)
                    raise Exception('Loss Exploded')

                if step % config.summary_interval == 0:
                    log('Writing summary at step: %d' % step)

                    feed_dict = {
                        **model.get_dummy_feed_dict(),
                        **test_model.get_dummy_feed_dict()
                    }
                    summary_writer.add_summary(
                        sess.run(test_stats, feed_dict=feed_dict), step)

                if step % config.checkpoint_interval == 0:
                    log('Saving checkpoint to: %s-%d' %
                        (checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)

                if step % config.test_interval == 0:
                    log('Saving audio and alignment...')
                    num_test = config.num_test

                    fetches = [
                        model.inputs[:num_test],
                        model.linear_outputs[:num_test],
                        model.alignments[:num_test],
                        test_model.inputs[:num_test],
                        test_model.linear_outputs[:num_test],
                        test_model.alignments[:num_test],
                    ]
                    feed_dict = {
                        **model.get_dummy_feed_dict(),
                        **test_model.get_dummy_feed_dict()
                    }

                    sequences, spectrograms, alignments, \
                            test_sequences, test_spectrograms, test_alignments = \
                                    sess.run(fetches, feed_dict=feed_dict)

                    save_and_plot(sequences[:1], spectrograms[:1],
                                  alignments[:1], log_dir, step, loss, "train")
                    save_and_plot(test_sequences, test_spectrograms,
                                  test_alignments, log_dir, step, loss, "test")

        except Exception as e:
            log('Exiting due to exception: %s' % e, slack=True)
            traceback.print_exc()
            coord.request_stop(e)
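
get_most_recent_checkpoint is not defined in this example. A minimal sketch of what such a helper might do, using the checkpoint-state file TensorFlow writes next to its checkpoints:

import tensorflow as tf

def get_most_recent_checkpoint(checkpoint_dir):
    # tf.train.get_checkpoint_state reads the 'checkpoint' file and reports the latest path (sketch)
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise ValueError('No checkpoint found in {}'.format(checkpoint_dir))
    return ckpt.model_checkpoint_path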
Example No. 24
0
    def load(self, checkpoint_path, hparams, gta=False, model_name='Tacotron'):
        log('Constructing model: %s' % model_name)
        #Force the batch size to be known in order to use attention masking in batch synthesis
        inputs = tf.placeholder(tf.int32, (None, None), name='inputs')
        input_lengths = tf.placeholder(tf.int32, (None), name='input_lengths')
        targets = tf.placeholder(tf.float32, (None, None, hparams.num_mels),
                                 name='mel_targets')
        split_infos = tf.placeholder(tf.int32,
                                     shape=(hparams.tacotron_num_gpus, None),
                                     name='split_infos')
        with tf.variable_scope('Tacotron_model', reuse=tf.AUTO_REUSE) as scope:
            self.model = create_model(model_name, hparams)
            if gta:
                self.model.initialize(inputs,
                                      input_lengths,
                                      targets,
                                      gta=gta,
                                      split_infos=split_infos)
            else:
                self.model.initialize(inputs,
                                      input_lengths,
                                      split_infos=split_infos)

            self.mel_outputs = self.model.tower_mel_outputs
            self.linear_outputs = self.model.tower_linear_outputs if (
                hparams.predict_linear and not gta) else None
            self.alignments = self.model.tower_alignments
            self.stop_token_prediction = self.model.tower_stop_token_prediction
            self.targets = targets

        if hparams.GL_on_GPU:
            self.GLGPU_mel_inputs = tf.placeholder(tf.float32,
                                                   (None, hparams.num_mels),
                                                   name='GLGPU_mel_inputs')
            self.GLGPU_lin_inputs = tf.placeholder(tf.float32,
                                                   (None, hparams.num_freq),
                                                   name='GLGPU_lin_inputs')

            self.GLGPU_mel_outputs = audio.inv_mel_spectrogram_tensorflow(
                self.GLGPU_mel_inputs, hparams)
            self.GLGPU_lin_outputs = audio.inv_linear_spectrogram_tensorflow(
                self.GLGPU_lin_inputs, hparams)

        self.gta = gta
        self._hparams = hparams
        #pad input sequences with the <pad_token> 0 ( _ )
        self._pad = 0
        #explicitly setting the padding to a value that doesn't originally exist in the spectrogram
        #to avoid any possible conflicts, without affecting the output range of the model too much
        if hparams.symmetric_mels:
            self._target_pad = -hparams.max_abs_value
        else:
            self._target_pad = 0.

        self.inputs = inputs
        self.input_lengths = input_lengths
        self.targets = targets
        self.split_infos = split_infos

        log('Loading checkpoint: %s' % checkpoint_path)
        #Memory allocation on the GPUs as needed
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        self.session = tf.Session(config=config)
        self.session.run(tf.global_variables_initializer())

        saver = tf.train.Saver()
        saver.restore(self.session, checkpoint_path)
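
A hedged usage sketch of the GL_on_GPU tensors built above: feeding a predicted mel spectrogram (frames x num_mels) through the in-graph Griffin-Lim instead of inverting it on the CPU. The synth argument and mel_spectrogram array are assumptions, not part of the original code:

def griffin_lim_on_gpu(synth, mel_spectrogram):
    # synth is an instance of the class above after load(); runs the TensorFlow Griffin-Lim graph (sketch)
    return synth.session.run(synth.GLGPU_mel_outputs,
                             feed_dict={synth.GLGPU_mel_inputs: mel_spectrogram})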
Example No. 25
0
    def load(self,
             checkpoint_path,
             hparams,
             gta=False,
             model_name='Tacotron',
             model_version=1):
        log('Constructing model: %s' % model_name)
        inputs = tf.placeholder(tf.int32, [None, 999], 'inputs')
        input_lengths = tf.placeholder(tf.int32, [None], 'input_lengths')
        targets = tf.placeholder(tf.float32, [None, None, hparams.num_mels],
                                 'mel_targets')
        with tf.variable_scope('Tacotron_model') as scope:
            self.model = create_model(model_name, hparams)
            if gta:
                self.model.initialize(inputs, input_lengths, targets, gta=gta)
            else:
                self.model.initialize(inputs, input_lengths)
            self.mel_outputs = self.model.mel_outputs
            self.linear_outputs = self.model.linear_outputs if (
                hparams.predict_linear and not gta) else None
            self.stop_token_prediction = self.model.stop_token_prediction
            self.alignments = self.model.alignments

        self.gta = gta
        self._hparams = hparams
        # pad input sequences with the <pad_token> 0 ( _ )
        self._pad = 0
        # explicitly setting the padding to a value that doesn't originally exist in the spectrogram
        # to avoid any possible conflicts, without affecting the output range of the model too much
        if hparams.symmetric_mels:
            self._target_pad = -hparams.max_abs_value
        else:
            self._target_pad = 0.

        log('Loading checkpoint: %s' % checkpoint_path)
        # Memory allocation on the GPU as needed
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        self.session = tf.Session(config=config)
        self.session.run(tf.global_variables_initializer())

        saver = tf.train.Saver()
        saver.restore(self.session, checkpoint_path)

        export_path_base = 'saved_model'
        export_path = os.path.join(tf.compat.as_bytes(export_path_base),
                                   tf.compat.as_bytes(str(model_version)))
        print('Exporting trained model to', export_path)
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)

        prediction_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs={
                    'inputs':
                    tf.saved_model.utils.build_tensor_info(inputs),
                    'input_lengths':
                    tf.saved_model.utils.build_tensor_info(input_lengths)
                },
                outputs={
                    'linear_outputs':
                    tf.saved_model.utils.build_tensor_info(
                        tf.convert_to_tensor(self.linear_outputs)),
                    'stop_token':
                    tf.saved_model.utils.build_tensor_info(
                        tf.convert_to_tensor(self.stop_token_prediction))
                },
                method_name=tf.saved_model.signature_constants.
                PREDICT_METHOD_NAME))

        builder.add_meta_graph_and_variables(
            self.session, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                prediction_signature,
            },
            main_op=tf.tables_initializer(),
            strip_default_attrs=True)

        builder.save()

        print("start build converter")
        converter = tf.contrib.lite.TFLiteConverter.from_session(
            self.session,
            input_tensors=[inputs, input_lengths],
            output_tensors=[
                tf.convert_to_tensor(self.linear_outputs),
                tf.convert_to_tensor(self.stop_token_prediction)
            ])

        converter.post_training_quantize = True
        print("start to quntized model")
        tflite_quantized_model = converter.convert()
        open("save_module/quantized_model.tflite",
             "wb").write(tflite_quantized_model)

        print('Done exporting!')
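
A minimal sketch of running the quantized model written above with the TFLite interpreter (tf.lite.Interpreter in recent 1.x releases, tf.contrib.lite.Interpreter in older ones). The dummy inputs and the way they are matched to input_details are assumptions and should be checked against get_input_details():

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='save_module/quantized_model.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# dummy encoded text padded to the fixed length 999 used by the inputs placeholder above
text_ids = np.zeros((1, 999), dtype=np.int32)
text_len = np.array([999], dtype=np.int32)

for detail in input_details:
    value = text_len if len(detail['shape']) == 1 else text_ids
    interpreter.set_tensor(detail['index'], value)

interpreter.invoke()
linear_outputs = interpreter.get_tensor(output_details[0]['index'])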