import os
import sys

import torch
from torch import optim
from torch.utils import data

# Project-level imports (Tacotron, symbols, config, the FileSourceDataset /
# data-source classes, collate_fn, ch2pinyin, synthesis_speech, get_test_args)
# are assumed to come from the surrounding package.


def initialize_training(checkpoint_path):

    # Input dataset definitions
    X = FileSourceDataset(TextDataSource())
    Mel = FileSourceDataset(MelSpecDataSource())
    Y = FileSourceDataset(LinearSpecDataSource())

    # Dataset and Dataloader setup
    dataset = PyTorchDataset(X, Mel, Y)
    data_loader = data.DataLoader(dataset,
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers,
                                  shuffle=True,
                                  collate_fn=collate_fn,
                                  pin_memory=config.pin_memory)

    # Model
    model = Tacotron(n_vocab=len(symbols),
                     embedding_dim=config.embedding_dim,
                     mel_dim=config.num_mels,
                     linear_dim=config.num_freq,
                     r=config.outputs_per_step,
                     padding_idx=config.padding_idx,
                     use_memory_mask=config.use_memory_mask)

    optimizer = optim.Adam(model.parameters(),
                           lr=config.initial_learning_rate,
                           betas=(config.adam_beta1, config.adam_beta2),
                           weight_decay=config.weight_decay)

    # Load checkpoint
    if checkpoint_path is not None:
        print("Loading checkpoint from: {}".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        # Resume bookkeeping counters; assumed to live at module scope in the
        # original script, hence the global declaration.
        global global_step, global_epoch
        try:
            global_step = checkpoint["global_step"]
            global_epoch = checkpoint["global_epoch"]
        except KeyError:
            print("Warning: unable to restore global step and global epoch!")
            sys.exit(1)

    return model, optimizer, data_loader
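
# A minimal sketch of how the returned triple might drive a training loop.
# The loss choice, batch layout, and model forward signature below are
# assumptions made for illustration; they are not taken from the original
# script.
def train_sketch(checkpoint_path=None, num_epochs=1):
    model, optimizer, data_loader = initialize_training(checkpoint_path)
    criterion = torch.nn.L1Loss()  # assumed reconstruction loss
    model.train()
    for epoch in range(num_epochs):
        for texts, mels, linears in data_loader:  # assumed collate_fn output
            optimizer.zero_grad()
            # Tacotron implementations commonly return (mel_outputs,
            # linear_outputs, alignments); treated here as an assumption.
            mel_outputs, linear_outputs, _ = model(texts, mels)
            loss = criterion(mel_outputs, mels) + criterion(linear_outputs, linears)
            loss.backward()
            optimizer.step()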
def main():

	#---initialize---#
	args = get_test_args()

	model = Tacotron(n_vocab=len(symbols),
					 embedding_dim=config.embedding_dim,
					 mel_dim=config.num_mels,
					 linear_dim=config.num_freq,
					 r=config.outputs_per_step,
					 padding_idx=config.padding_idx,
					 use_memory_mask=config.use_memory_mask)

	#---handle path---#
	checkpoint_path = os.path.join(args.ckpt_dir, args.checkpoint_name + args.model_name + '.pth')
	os.makedirs(args.result_dir, exist_ok=True)
	
	#---load and set model---#
	print('Loading model: ', checkpoint_path)
	checkpoint = torch.load(checkpoint_path)
	model.load_state_dict(checkpoint["state_dict"])
	
	if args.long_input:
		model.decoder.max_decoder_steps = 500  # set a large max_decoder_steps to handle long sentence outputs
	else:
		model.decoder.max_decoder_steps = 50
		
	if args.interactive:
		# args.model_name matches the attribute used for the checkpoint path above
		output_name = os.path.join(args.result_dir, args.model_name)

		#---testing loop---#
		while True:
			try:
				text = input('< Tacotron > Text to speech: ')
				text = ch2pinyin(text)
				print('Model input: ', text)
				synthesis_speech(model, text=text, figures=args.plot, path=output_name)
			except KeyboardInterrupt:
				print()
				print('Terminating!')
				break

	else:
		output_name = os.path.join(args.result_dir, args.model_name)
		os.makedirs(output_name, exist_ok=True)

		#---testing flow---#
		with open(args.test_file_path, 'r', encoding='utf-8') as f:

			lines = f.readlines()
			for idx, line in enumerate(lines):
				line = line.strip()  # drop the trailing newline before logging and file naming
				text = ch2pinyin(line)
				print("{}: {} - {} ({} chars in, {} chars out)".format(idx, line, text, len(line), len(text)))
				synthesis_speech(model, text=text, figures=args.plot, path=os.path.join(output_name, line))

		print("Finished! Check out {} for generated audio samples.".format(output_name))
	
		
	sys.exit(0)
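

# A minimal sketch of the argument parser main() appears to expect. The flag
# names are inferred from the attributes used above; all defaults are
# assumptions made for illustration.
import argparse

def get_test_args():
	parser = argparse.ArgumentParser(description='Tacotron test-time synthesis')
	parser.add_argument('--ckpt_dir', default='ckpt/')
	parser.add_argument('--checkpoint_name', default='checkpoint_')
	parser.add_argument('--model_name', default='tacotron')
	parser.add_argument('--result_dir', default='result/')
	parser.add_argument('--test_file_path', default='test.txt')
	parser.add_argument('--interactive', action='store_true')
	parser.add_argument('--long_input', action='store_true')
	parser.add_argument('--plot', action='store_true')
	return parser.parse_args()


if __name__ == '__main__':
	main()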