import os
import sys

import torch
from torch import optim
from torch.utils import data

# NOTE: project-specific names (Tacotron, config, symbols, collate_fn, the
# FileSourceDataset / *DataSource / PyTorchDataset classes, get_test_args,
# ch2pinyin, synthesis_speech) are assumed to be imported from elsewhere in
# the repository.


def initialize_training(checkpoint_path):
    # Input dataset definitions
    X = FileSourceDataset(TextDataSource())
    Mel = FileSourceDataset(MelSpecDataSource())
    Y = FileSourceDataset(LinearSpecDataSource())

    # Dataset and DataLoader setup
    dataset = PyTorchDataset(X, Mel, Y)
    data_loader = data.DataLoader(dataset,
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers,
                                  shuffle=True,
                                  collate_fn=collate_fn,
                                  pin_memory=config.pin_memory)

    # Model
    model = Tacotron(n_vocab=len(symbols),
                     embedding_dim=config.embedding_dim,
                     mel_dim=config.num_mels,
                     linear_dim=config.num_freq,
                     r=config.outputs_per_step,
                     padding_idx=config.padding_idx,
                     use_memory_mask=config.use_memory_mask)

    optimizer = optim.Adam(model.parameters(),
                           lr=config.initial_learning_rate,
                           betas=(config.adam_beta1, config.adam_beta2),
                           weight_decay=config.weight_decay)

    # Load checkpoint
    if checkpoint_path is not None:
        print("Load checkpoint from: {}".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        try:
            # Training counters are assumed to be tracked as module-level
            # globals, so restoring them here is visible to the training loop.
            global global_step, global_epoch
            global_step = checkpoint["global_step"]
            global_epoch = checkpoint["global_epoch"]
        except KeyError:
            print('Error: global step and global epoch could not be restored from the checkpoint!')
            sys.exit(1)

    return model, optimizer, data_loader
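# For reference, a minimal sketch of the checkpoint format that
# initialize_training() expects to restore. The dictionary keys mirror the ones
# read above ("state_dict", "optimizer", "global_step", "global_epoch"); the
# helper name save_checkpoint and the file-name pattern are assumptions, not
# part of the original training script.
def save_checkpoint(model, optimizer, global_step, global_epoch, ckpt_dir):
    checkpoint_path = os.path.join(ckpt_dir, "checkpoint_step{:09d}.pth".format(global_step))
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "global_step": global_step,
        "global_epoch": global_epoch,
    }, checkpoint_path)
    print("Saved checkpoint: {}".format(checkpoint_path))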
def main():
    #---initialize---#
    args = get_test_args()

    model = Tacotron(n_vocab=len(symbols),
                     embedding_dim=config.embedding_dim,
                     mel_dim=config.num_mels,
                     linear_dim=config.num_freq,
                     r=config.outputs_per_step,
                     padding_idx=config.padding_idx,
                     use_memory_mask=config.use_memory_mask)

    #---handle path---#
    checkpoint_path = os.path.join(args.ckpt_dir, args.checkpoint_name + args.model_name + '.pth')
    os.makedirs(args.result_dir, exist_ok=True)

    #---load and set model---#
    print('Loading model: ', checkpoint_path)
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint["state_dict"])

    if args.long_input:
        model.decoder.max_decoder_steps = 500  # Use a large decoder-step limit to handle long input sentences
    else:
        model.decoder.max_decoder_steps = 50

    if args.interactive:
        output_name = args.result_dir + args.model

        #---testing loop---#
        while True:
            try:
                text = str(input('< Tacotron > Text to speech: '))
                text = ch2pinyin(text)
                print('Model input: ', text)
                synthesis_speech(model, text=text, figures=args.plot, path=output_name)
            except KeyboardInterrupt:
                print()
                print('Terminating!')
                break
    else:
        output_name = args.result_dir + args.model + '/'
        os.makedirs(output_name, exist_ok=True)

        #---testing flow---#
        with open(args.test_file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        for idx, line in enumerate(lines):
            line = line.strip()  # drop the trailing newline so it does not end up in the output path
            text = ch2pinyin(line)
            print("{}: {} - {} ({} words, {} chars)".format(idx, line, text, len(line), len(text)))
            synthesis_speech(model, text=text, figures=args.plot, path=output_name + line)
        print("Finished! Check out {} for generated audio samples.".format(output_name))

    sys.exit(0)
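# Script entry point; this guard is assumed, since it is not shown in the
# excerpt above.
if __name__ == '__main__':
    main()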