    def __init__(self,
                 ckpt_file,
                 device='cuda',
                 use_fp16=False,
                 use_denoiser=False):
        self.ckpt_file = ckpt_file
        self.device = device
        self.use_fp16 = use_fp16
        self.use_denoiser = use_denoiser

        # model
        sys.path.append('waveglow')
        self.model = torch.load(self.ckpt_file,
                                map_location=self.device)['model']
        self.model = self.model.remove_weightnorm(self.model)
        self.model.eval()
        self.model = to_device_async(self.model, self.device)
        if self.use_fp16:
            self.model = self.model.half()

        if self.use_denoiser:
            self.denoiser = Denoiser(self.model, device=device)
            self.denoiser = to_device_async(self.denoiser, self.device)

            tprint('Using WaveGlow denoiser.')
    def __init__(self, data_loader, model_name, model, optimizer_fn, final_steps,
                 lr_scheduler_fn=None, step=0, ckpt_path=None, log_path=None,
                 n_epochs=None, save_steps=None, log_steps=10, device='cuda',
                 use_amp='O0', nvprof_iter_start=None, nvprof_iter_end=None,
                 pyprof_enabled=False, detect_anomaly=False, seed=None,
                 pre_aligns=True):
        super(FastspeechTrainer, self).__init__(data_loader, model_name, model,
                                                optimizer_fn, final_steps,
                                                lr_scheduler_fn, step, ckpt_path,
                                                log_path, n_epochs, save_steps,
                                                log_steps, device, use_amp,
                                                nvprof_iter_start, nvprof_iter_end,
                                                pyprof_enabled, detect_anomaly,
                                                seed)
        self.pre_aligns = pre_aligns

        if not pre_aligns:
            self.tacotron2 = get_tacotron2(device, is_training=True)
            to_device_async(self.tacotron2, device)
    def loss(self, inputs, model):
        text = inputs["text_encoded"]
        text_pos = inputs["text_pos"]
        mel_tgt = inputs["mel"]

        text = to_device_async(text, self.device)
        text_pos = to_device_async(text_pos, self.device)
        mel_tgt = to_device_async(mel_tgt, self.device)

        if self.pre_aligns:
            dur_tgt = inputs["align"]  # preprocessed align
            dur_tgt = dur_tgt.float()
            dur_tgt = to_device_async(dur_tgt, self.device)
        else:
            text_len = inputs['text_len']
            mel_len = inputs['mel_len']
            dur_tgt = get_duration(
                text, text_len, mel_tgt, mel_len, self.tacotron2, self.device)

        # (B,H,T) => (B,T,H)
        mel_tgt = mel_tgt.transpose(1, 2)

        # Forward
        mel, mask, dur = model(
            text,
            text_pos,
            duration_target=dur_tgt,
            seq_output_len=mel_tgt.size(1))
        assert mel.size(1) == mel_tgt.size(1)

        # Loss
        mel_loss = F.mse_loss(mel, mel_tgt, reduction='none')
        mel_mask = mel_tgt.ne(0).float()
        mel_loss *= mel_mask
        mel_loss = mel_loss.mean()

        dur_tgt = torch.log(dur_tgt + 1)
        dur_mask = text_pos.ne(0).float()
        dur_tgt *= dur_mask

        dur_pred_loss = F.mse_loss(dur, dur_tgt)

        loss = mel_loss + dur_pred_loss

        meta = {
            'mel_loss': to_cpu_numpy(mel_loss),
            'duration_predictor_loss': to_cpu_numpy(dur_pred_loss),
        }

        return loss, meta
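
The masked mel loss above can be illustrated with a minimal standalone sketch (shapes and values are hypothetical); note that .mean() averages over all elements, padded positions included, which mirrors the code above.

import torch
import torch.nn.functional as F

# (B, T, H) predicted and target mels; frames after index 60 simulate padding.
mel = torch.randn(2, 100, 80)
mel_tgt = torch.randn(2, 100, 80)
mel_tgt[:, 60:, :] = 0.0

mel_loss = F.mse_loss(mel, mel_tgt, reduction='none')  # elementwise squared error
mel_mask = mel_tgt.ne(0).float()                       # 1.0 on real frames, 0.0 on padding
mel_loss = (mel_loss * mel_mask).mean()                # averaged over all elements, as above
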
    def __init__(self, model_name, model, data_loader=None, ckpt_path=None,
                 ckpt_file=None, log_path=None, device='cuda', use_fp16=False,
                 seed=None):
        self.data_loader = data_loader
        self.model_name = model_name
        self.model = model
        self.ckpt_path = ckpt_path
        self.log_path = log_path
        self.device = device
        self.seed = seed
        self.step = 0
        self.ckpt_file = ckpt_file
        self.use_fp16 = use_fp16

        # model
        self.model.eval()
        to_device_async(self.model, self.device)
        num_param = sum(param.numel() for param in model.parameters())
        tprint('The number of {} parameters: {}'.format(self.model_name, num_param))

        # precision
        if self.use_fp16:
            self.model = self.model.half()

        # data parallel
        self.model = nn.DataParallel(self.model)

        # set seed
        if seed is None:
            seed = np.random.randint(2**16)
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.data_loader_iter = iter(self.data_loader) if self.data_loader else None

        # logging
        if log_path:
            # tensorboard log path : {log_path}/YYYYMMDD-HHMMSS
            log_path = os.path.join(log_path, time.strftime('%Y%m%d-%H%M%S'))
            self.tbwriter = SummaryWriter(log_dir=log_path, flush_secs=10)

        # checkpoint path
        if self.ckpt_path:
            self.ckpt_path = os.path.join(self.ckpt_path, self.model_name)
            pathlib.Path(self.ckpt_path).mkdir(parents=True, exist_ok=True)

            # load checkpoint
            self.load(ckpt_file)
Example #5
    def __init__(self, ckpt_file, device='cuda', use_fp16=False, use_denoiser=False):
        self.ckpt_file = ckpt_file
        self.device = device
        self.use_fp16 = use_fp16
        self.use_denoiser = use_denoiser

        # model
        # sys.path.append('waveglow')

        from waveglow.arg_parser import parse_waveglow_args
        parser = argparse.ArgumentParser()
        model_parser = parse_waveglow_args(parser)
        args, _ = model_parser.parse_known_args()
        model_config = dict(
            n_mel_channels=args.n_mel_channels,
            n_flows=args.flows,
            n_group=args.groups,
            n_early_every=args.early_every,
            n_early_size=args.early_size,
            WN_config=dict(
                n_layers=args.wn_layers,
                kernel_size=args.wn_kernel_size,
                n_channels=args.wn_channels
            )
        )
        self.model = WaveGlow(**model_config)

        state_dict = torch.load(self.ckpt_file, map_location=self.device)['state_dict']
        state_dict = unwrap_distributed(state_dict)
        self.model.load_state_dict(state_dict)

        self.model = to_device_async(self.model, self.device)

        self.model = self.model.remove_weightnorm(self.model)

        self.model.eval()

        if self.use_fp16:
            self.model = self.model.half()

        if self.use_denoiser:
            self.denoiser = Denoiser(self.model, device=device)
            self.denoiser = to_device_async(self.denoiser, self.device)

            tprint('Using WaveGlow denoiser.')
Example #6
    def infer(self, mels):
        if self.use_fp16:
            mels = mels.half()
        mels = to_device_async(mels, self.device)
        wavs = self.model.infer(mels, sigma=0.6)

        if self.use_denoiser:
            wavs = self.denoiser(wavs, strength=0.01)

        return wavs.float()
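
A hypothetical usage of this vocoder wrapper (the class name WaveGlowInferencer and the checkpoint filename below are assumptions, not taken from the source):

import torch

vocoder = WaveGlowInferencer(ckpt_file='waveglow_ckpt.pt', device='cuda',
                             use_fp16=True, use_denoiser=True)
mels = torch.randn(1, 80, 620, device='cuda')  # (B, n_mel_channels, T)
with torch.no_grad():
    wavs = vocoder.infer(mels)                 # waveform tensor, returned as float32
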
Example #7
    def infer(self, acts=None, seq_input_len=None, seq_output_len=None):
        inputs = next(self.data_loader_iter)

        text_encoded = inputs["text_encoded"]
        text_pos = inputs["text_pos"]

        if seq_input_len:
            text_encoded = F.pad(text_encoded,
                                 pad=(0, seq_input_len -
                                      text_encoded.size(1)))  # (b, t)
            text_pos = F.pad(text_pos,
                             pad=(0,
                                  seq_input_len - text_pos.size(1)))  # (b, t)

        text_encoded = to_device_async(text_encoded, self.device)
        text_pos = to_device_async(text_pos, self.device)

        mel, mel_mask, _ = self.model(seq=text_encoded,
                                      pos=text_pos,
                                      seq_output_len=seq_output_len,
                                      use_fp16=self.use_fp16,
                                      acts=acts)

        # (B,T,H) => (B,H,T)
        mel = mel.transpose(1, 2)
        mel_mask = mel_mask.squeeze(2)

        outputs = dict()
        outputs['mel'] = mel
        outputs['mel_mask'] = mel_mask
        outputs['text'] = inputs["text_norm"]

        if "mel" in inputs:
            outputs['mel_tgt'] = inputs["mel"]

        if "wav" in inputs:
            outputs['wav_tgt'] = inputs["wav"]

        if "sr" in inputs:
            outputs['sr'] = inputs["sr"]

        return outputs
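
When seq_input_len is given, infer() right-pads the encoded text and positions with zeros to a fixed length, which is useful for fixed-shape inference. A minimal sketch with hypothetical values:

import torch
import torch.nn.functional as F

text_encoded = torch.randint(1, 100, (1, 37))  # (B, T) encoded text, T = 37
seq_input_len = 128

# F.pad with pad=(left, right) pads the last (time) dimension; here, zeros on the right.
text_encoded = F.pad(text_encoded, pad=(0, seq_input_len - text_encoded.size(1)))
print(text_encoded.shape)  # torch.Size([1, 128])
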
Example #8
    def get_output(self, input, duration, alpha):
        output, output_pos = list(), list()
        # TODO: parallelize the loop.
        for i in range(input.size(0)):
            repeats = duration[i].float() * alpha
            with Nvtx("round #{}".format(i), enabled=False):
                repeats = torch.round(repeats).long()
            with Nvtx("repeat #{}".format(i), enabled=False):
                output.append(torch.repeat_interleave(input[i], repeats,
                                                      dim=0))
            output_pos.append(
                torch.arange(1, output[i].size(0) + 1))
        output = pad_sequence(output, batch_first=True)
        output_pos = pad_sequence(output_pos, batch_first=True)

        with Nvtx("pos to gpu", enabled=False):
            output_pos = to_device_async(output_pos, device=output.device)

        return output, output_pos
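
get_output() performs duration-based expansion (FastSpeech length regulation): each time step of the input is repeated according to its duration, scaled by alpha. A minimal single-item sketch with hypothetical values:

import torch

hidden = torch.randn(3, 4)            # (T_text, H) hidden states for one item
duration = torch.tensor([2, 1, 3])    # frames per input step
alpha = 1.0                           # speaking-rate control

repeats = torch.round(duration.float() * alpha).long()
expanded = torch.repeat_interleave(hidden, repeats, dim=0)  # (repeats.sum(), H) = (6, 4)
positions = torch.arange(1, expanded.size(0) + 1)           # 1-based frame positions
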
Example #9
    def __init__(self,
                 data_loader,
                 model_name,
                 model,
                 optimizer_fn,
                 final_steps,
                 lr_scheduler_fn=None,
                 step=0,
                 ckpt_path=None,
                 log_path=None,
                 n_epochs=None,
                 save_steps=None,
                 log_steps=10,
                 device='cuda',
                 use_amp=False,
                 nvprof_iter_start=None,
                 nvprof_iter_end=None,
                 pyprof_enabled=False,
                 detect_anomaly=False,
                 seed=None):
        self.data_loader = data_loader
        self.model_name = model_name
        self.model = model
        self.n_epochs = n_epochs
        self.save_steps = save_steps
        self.log_steps = log_steps
        self.ckpt_path = ckpt_path
        self.log_path = log_path
        self.final_steps = final_steps
        self.step = step
        self.device = device
        self.use_amp = use_amp
        self.nvprof_iter_start = nvprof_iter_start
        self.nvprof_iter_end = nvprof_iter_end
        self.pyprof_enabled = pyprof_enabled
        self.detect_anomaly = detect_anomaly

        # model
        self.model.train()
        to_device_async(self.model, self.device)
        num_param = sum(param.numel() for param in model.parameters())
        tprint('The number of {} parameters: {}'.format(
            self.model_name, num_param))

        # optimizer
        self.optimizer = optimizer_fn(model)

        # lr scheduler
        if lr_scheduler_fn:
            self.lr_scheduler = lr_scheduler_fn(self.optimizer)
        else:
            self.lr_scheduler = None

        # automatic mixed precision (use_amp may be an apex opt level string, e.g. 'O0'/'O1')
        if self.use_amp and self.use_amp != 'O0':
            from apex import amp
            opt_level = self.use_amp if isinstance(self.use_amp, str) else 'O1'
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level=opt_level)

        # profile
        if nvprof_iter_start is not None and nvprof_iter_end is not None and pyprof_enabled:
            from apex import pyprof
            pyprof.nvtx.init()

        # data parallel
        self.model = nn.DataParallel(self.model)

        # set seed
        if seed is None:
            seed = np.random.randint(2**16)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # data loader
        self.data_loader_iter = self.repeat(self.data_loader, n_epochs)

        # logging
        if log_path:
            # tensorboard log path : {log_path}/YYYYMMDD-HHMMSS
            log_path = os.path.join(log_path, time.strftime('%Y%m%d-%H%M%S'))
            self.tbwriter = SummaryWriter(log_dir=log_path, flush_secs=10)

        # checkpoint path
        if self.ckpt_path:
            self.ckpt_path = os.path.join(self.ckpt_path, self.model_name)
            pathlib.Path(self.ckpt_path).mkdir(parents=True, exist_ok=True)

            # load checkpoint
            self.load()
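
With apex AMP initialized as above, the backward pass of each training step is wrapped in amp.scale_loss. A hedged sketch of such a step (model, optimizer, and compute_loss are placeholders, not the trainer's actual step code):

from apex import amp

def train_step(model, optimizer, compute_loss, inputs):
    loss, meta = compute_loss(inputs, model)
    optimizer.zero_grad()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()  # backward on the AMP-scaled loss
    optimizer.step()
    return loss, meta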