def __init__(self, ckpt_file, device='cuda', use_fp16=False, use_denoiser=False):
    """Load a pickled WaveGlow model checkpoint for inference.

    Args:
        ckpt_file: path to a checkpoint whose 'model' entry is the pickled model.
        device: target device for the model (default 'cuda').
        use_fp16: if True, cast the model to half precision.
        use_denoiser: if True, attach a Denoiser to suppress vocoder bias noise.
    """
    self.ckpt_file = ckpt_file
    self.device = device
    self.use_fp16 = use_fp16
    self.use_denoiser = use_denoiser

    # model
    # The checkpoint pickles classes from the 'waveglow' package, so it must
    # be importable when torch.load unpickles the model.
    sys.path.append('waveglow')
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    self.model = torch.load(self.ckpt_file, map_location=self.device)['model']
    self.model = self.model.remove_weightnorm(self.model)
    self.model.eval()
    self.model = to_device_async(self.model, self.device)

    if self.use_fp16:
        self.model = self.model.half()
    # BUGFIX: removed dead no-op statement `self.model = self.model`.

    if self.use_denoiser:
        self.denoiser = Denoiser(self.model, device=device)
        self.denoiser = to_device_async(self.denoiser, self.device)
        tprint('Using WaveGlow denoiser.')
def __init__(self, data_loader, model_name, model, optimizer_fn, final_steps,
             lr_scheduler_fn=None, step=0, ckpt_path=None, log_path=None,
             n_epochs=None, save_steps=None, log_steps=10, device='cuda',
             use_amp='O0', nvprof_iter_start=None, nvprof_iter_end=None,
             pyprof_enabled=False, detect_anomaly=False, seed=None,
             pre_aligns=True):
    """FastSpeech trainer.

    Delegates generic training setup to the base trainer; when precomputed
    alignments are unavailable, builds a Tacotron2 teacher to extract
    phoneme durations on the fly.
    """
    super(FastspeechTrainer, self).__init__(
        data_loader, model_name, model, optimizer_fn, final_steps,
        lr_scheduler_fn, step, ckpt_path, log_path, n_epochs, save_steps,
        log_steps, device, use_amp, nvprof_iter_start, nvprof_iter_end,
        pyprof_enabled, detect_anomaly, seed)
    self.pre_aligns = pre_aligns
    if pre_aligns:
        return
    # No precomputed alignments: durations come from a Tacotron2 teacher.
    self.tacotron2 = get_tacotron2(device, is_training=True)
    to_device_async(self.tacotron2, device)
def loss(self, inputs, model):
    """Compute the FastSpeech training loss.

    Returns a scalar loss (masked mel MSE + log-duration MSE) and a meta
    dict with both components as CPU numpy values.
    """
    device = self.device
    text = to_device_async(inputs["text_encoded"], device)
    text_pos = to_device_async(inputs["text_pos"], device)
    mel_tgt = to_device_async(inputs["mel"], device)

    if self.pre_aligns:
        # Durations were preprocessed offline.
        dur_tgt = to_device_async(inputs["align"].float(), device)
    else:
        # Extract target durations on the fly from the Tacotron2 teacher.
        dur_tgt = get_duration(text, inputs['text_len'], mel_tgt,
                               inputs['mel_len'], self.tacotron2, device)

    # (B,H,T) => (B,T,H)
    mel_tgt = mel_tgt.transpose(1, 2)

    # Forward
    mel, mask, dur = model(text, text_pos, duration_target=dur_tgt,
                           seq_output_len=mel_tgt.size(1))
    assert mel.size(1) == mel_tgt.size(1)

    # Mel loss, masked so padding frames do not contribute.
    mel_mask = mel_tgt.ne(0).float()
    mel_loss = (F.mse_loss(mel, mel_tgt, reduction='none') * mel_mask).mean()

    # Duration loss in log space, masked over padded text positions.
    dur_tgt = torch.log(dur_tgt + 1) * text_pos.ne(0).float()
    dur_pred_loss = F.mse_loss(dur, dur_tgt)

    meta = {
        'mel_loss': to_cpu_numpy(mel_loss),
        'duration_predictor_loss': to_cpu_numpy(dur_pred_loss),
    }
    return mel_loss + dur_pred_loss, meta
def __init__(self, model_name, model, data_loader=None, ckpt_path=None,
             ckpt_file=None, log_path=None, device='cuda', use_fp16=False,
             seed=None):
    """Inference driver: prepares the model, seeding, logging and checkpoint
    loading.

    Args:
        model_name: name used for logging and the checkpoint subdirectory.
        model: the model to run inference with.
        data_loader: optional iterable of input batches.
        ckpt_path: directory containing checkpoints.
        ckpt_file: specific checkpoint file to load (passed to self.load).
        log_path: if set, tensorboard logs go to {log_path}/YYYYMMDD-HHMMSS.
        device: target device.
        use_fp16: if True, cast the model to half precision.
        seed: RNG seed; a random one is drawn when None.
    """
    self.data_loader = data_loader
    self.model_name = model_name
    self.model = model
    self.ckpt_path = ckpt_path
    self.log_path = log_path
    self.device = device
    self.seed = seed
    self.step = 0
    self.ckpt_file = ckpt_file
    self.use_fp16 = use_fp16

    # model
    self.model.eval()
    to_device_async(self.model, self.device)
    num_param = sum(param.numel() for param in model.parameters())
    tprint('The number of {} parameters: {}'.format(self.model_name, num_param))

    # precision
    if self.use_fp16:
        self.model = self.model.half()

    # data parallel
    self.model = nn.DataParallel(self.model)

    # set seed
    if seed is None:
        seed = np.random.randint(2**16)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # BUGFIX: data_loader defaults to None; iter(None) raised TypeError.
    self.data_loader_iter = iter(self.data_loader) if self.data_loader is not None else None

    # logging
    if log_path:
        # tensorboard log path : {log_path}/YYYYMMDD-HHMMSS
        log_path = os.path.join(log_path, time.strftime('%Y%m%d-%H%M%S'))
        self.tbwriter = SummaryWriter(log_dir=log_path, flush_secs=10)

    # checkpoint path
    if self.ckpt_path:
        self.ckpt_path = os.path.join(self.ckpt_path, self.model_name)
        pathlib.Path(self.ckpt_path).mkdir(parents=True, exist_ok=True)

    # load checkpoint
    self.load(ckpt_file)
def __init__(self, ckpt_file, device='cuda', use_fp16=False, use_denoiser=False):
    """Build a WaveGlow vocoder from a state-dict checkpoint.

    Reconstructs the WaveGlow architecture from its default CLI arguments,
    then loads weights from the checkpoint's 'state_dict' entry.

    Args:
        ckpt_file: path to a checkpoint containing a 'state_dict' entry.
        device: target device for the model (default 'cuda').
        use_fp16: if True, cast the model to half precision.
        use_denoiser: if True, attach a Denoiser to suppress vocoder bias noise.
    """
    self.ckpt_file = ckpt_file
    self.device = device
    self.use_fp16 = use_fp16
    self.use_denoiser = use_denoiser

    # model
    from waveglow.arg_parser import parse_waveglow_args
    parser = argparse.ArgumentParser()  # BUGFIX: was `parser = parser = argparse.ArgumentParser()`
    model_parser = parse_waveglow_args(parser)
    args, _ = model_parser.parse_known_args()
    model_config = dict(
        n_mel_channels=args.n_mel_channels,
        n_flows=args.flows,
        n_group=args.groups,
        n_early_every=args.early_every,
        n_early_size=args.early_size,
        WN_config=dict(
            n_layers=args.wn_layers,
            kernel_size=args.wn_kernel_size,
            n_channels=args.wn_channels
        )
    )
    self.model = WaveGlow(**model_config)

    state_dict = torch.load(self.ckpt_file, map_location=self.device)['state_dict']
    # Strip distributed-training wrappers ('module.' prefixes) before loading.
    state_dict = unwrap_distributed(state_dict)
    self.model.load_state_dict(state_dict)

    self.model = to_device_async(self.model, self.device)
    self.model = self.model.remove_weightnorm(self.model)
    self.model.eval()

    if self.use_fp16:
        self.model = self.model.half()
    # BUGFIX: removed dead no-op statement `self.model = self.model`.

    if self.use_denoiser:
        self.denoiser = Denoiser(self.model, device=device)
        self.denoiser = to_device_async(self.denoiser, self.device)
        tprint('Using WaveGlow denoiser.')
def infer(self, mels):
    """Vocode mel-spectrograms into waveforms with WaveGlow.

    Optionally casts inputs to fp16 and applies the denoiser; the result
    is always returned as float32.
    """
    if self.use_fp16:
        mels = mels.half()
    mels = to_device_async(mels, self.device)
    audio = self.model.infer(mels, sigma=0.6)
    if self.use_denoiser:
        audio = self.denoiser(audio, strength=0.01)
    return audio.float()
def infer(self, acts=None, seq_input_len=None, seq_output_len=None):
    """Run FastSpeech on the next batch from the data loader.

    Optionally right-pads the text inputs to seq_input_len and fixes the
    mel output length to seq_output_len. Returns a dict with the predicted
    mel, its mask, the normalized text, and any available targets.
    """
    batch = next(self.data_loader_iter)
    text_encoded = batch["text_encoded"]
    text_pos = batch["text_pos"]

    if seq_input_len:
        # Right-pad token ids and positions to the fixed input length. (b, t)
        text_encoded = F.pad(text_encoded, pad=(0, seq_input_len - text_encoded.size(1)))
        text_pos = F.pad(text_pos, pad=(0, seq_input_len - text_pos.size(1)))

    text_encoded = to_device_async(text_encoded, self.device)
    text_pos = to_device_async(text_pos, self.device)

    mel, mel_mask, _ = self.model(seq=text_encoded,
                                  pos=text_pos,
                                  seq_output_len=seq_output_len,
                                  use_fp16=self.use_fp16,
                                  acts=acts)

    # (B,T,H) => (B,H,T)
    mel = mel.transpose(1, 2)
    mel_mask = mel_mask.squeeze(2)

    outputs = {
        'mel': mel,
        'mel_mask': mel_mask,
        'text': batch["text_norm"],
    }
    # Pass ground-truth targets through when the dataset provides them.
    for src_key, dst_key in (("mel", "mel_tgt"), ("wav", "wav_tgt"), ("sr", "sr")):
        if src_key in batch:
            outputs[dst_key] = batch[src_key]
    return outputs
def get_output(self, input, duration, alpha):
    """Length regulator: expand each encoder frame by its (scaled) duration.

    Args:
        input: (B, T, H) encoder outputs.
        duration: (B, T) per-token durations.
        alpha: speed factor; durations are multiplied by alpha then rounded.

    Returns:
        (output, output_pos): padded expanded sequences and their 1-based
        position indices (on the same device as output).
    """
    output, output_pos = list(), list()
    # TODO: parallelize the loop.
    for i in range(input.size(0)):
        repeats = duration[i].float() * alpha
        with Nvtx("round #{}".format(i), enabled=False):
            repeats = torch.round(repeats).long()
        with Nvtx("repeat #{}".format(i), enabled=False):
            output.append(torch.repeat_interleave(input[i], repeats, dim=0))
        # 1-based positions. IMPROVED: torch.arange replaces the equivalent
        # numpy round-trip torch.from_numpy(np.indices((n,))[0] + 1); both
        # produce an int64 CPU tensor [1..n].
        output_pos.append(torch.arange(1, output[i].size(0) + 1))
    output = pad_sequence(output, batch_first=True)
    output_pos = pad_sequence(output_pos, batch_first=True)
    with Nvtx("pos to gpu", enabled=False):
        output_pos = to_device_async(output_pos, device=output.device)
    return output, output_pos
def __init__(self, data_loader, model_name, model, optimizer_fn, final_steps,
             lr_scheduler_fn=None, step=0, ckpt_path=None, log_path=None,
             n_epochs=None, save_steps=None, log_steps=10, device='cuda',
             use_amp=False, nvprof_iter_start=None, nvprof_iter_end=None,
             pyprof_enabled=False, detect_anomaly=False, seed=None):
    """Base trainer: sets up model, optimizer, AMP, profiling, seeding,
    logging, and checkpoint loading.

    Args:
        data_loader: iterable of training batches.
        model_name: name used for logging and the checkpoint subdirectory.
        model: the model to train.
        optimizer_fn: callable(model) -> optimizer.
        final_steps: total number of training steps.
        lr_scheduler_fn: optional callable(optimizer) -> LR scheduler.
        step: starting step (for resumed training).
        use_amp: falsy to disable AMP; an apex opt level string ('O1', ...)
            or True (treated as 'O1') to enable it.
        nvprof_iter_start/nvprof_iter_end: iteration window for nvprof.
        pyprof_enabled: enable apex pyprof NVTX annotations.
        seed: RNG seed; a random one is drawn when None.
    """
    self.data_loader = data_loader
    self.model_name = model_name
    self.model = model
    self.n_epochs = n_epochs
    self.save_steps = save_steps
    self.log_steps = log_steps
    self.ckpt_path = ckpt_path
    self.log_path = log_path
    self.final_steps = final_steps
    self.step = step
    self.device = device
    self.use_amp = use_amp
    self.nvprof_iter_start = nvprof_iter_start
    self.nvprof_iter_end = nvprof_iter_end
    self.pyprof_enabled = pyprof_enabled
    self.detect_anomaly = detect_anomaly

    # model
    self.model.train()
    to_device_async(self.model, self.device)
    num_param = sum(param.numel() for param in model.parameters())
    tprint('The number of {} parameters: {}'.format(
        self.model_name, num_param))

    # optimizer
    self.optimizer = optimizer_fn(model)

    # lr scheduler
    if lr_scheduler_fn:
        self.lr_scheduler = lr_scheduler_fn(self.optimizer)
    else:
        self.lr_scheduler = None

    # automatic mixed precision
    if self.use_amp:
        from apex import amp
        # BUGFIX: honor a string opt level (callers pass e.g. 'O0'); a bare
        # True still selects 'O1' for backward compatibility. Previously
        # 'O1' was hard-coded regardless of the requested level.
        opt_level = self.use_amp if isinstance(self.use_amp, str) else 'O1'
        self.model, self.optimizer = amp.initialize(self.model,
                                                    self.optimizer,
                                                    opt_level=opt_level)

    # profile
    # BUGFIX: the original condition `nvprof_iter_start and nvprof_iter_end
    # is not None and pyprof_enabled` tested nvprof_iter_start for truthiness
    # (skipping a start iteration of 0) instead of for None.
    if (nvprof_iter_start is not None and nvprof_iter_end is not None
            and pyprof_enabled):
        from apex import pyprof
        pyprof.nvtx.init()

    # data parallel
    self.model = nn.DataParallel(self.model)

    # set seed
    if seed is None:
        seed = np.random.randint(2**16)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # data loader
    self.data_loader_iter = self.repeat(self.data_loader, n_epochs)

    # logging
    if log_path:
        # tensorboard log path : {log_path}/YYYYMMDD-HHMMSS
        log_path = os.path.join(log_path, time.strftime('%Y%m%d-%H%M%S'))
        self.tbwriter = SummaryWriter(log_dir=log_path, flush_secs=10)

    # checkpoint path
    if self.ckpt_path:
        self.ckpt_path = os.path.join(self.ckpt_path, self.model_name)
        pathlib.Path(self.ckpt_path).mkdir(parents=True, exist_ok=True)

    # load checkpoint
    self.load()