Example #1
def run(args):
    """Runs the algorithm."""
    Path(hp.output_path).mkdir(parents=True, exist_ok=True)

    # setup nnabla context
    ctx = get_extension_context(args.context, device_id='0')
    nn.set_default_context(ctx)
    hp.comm = CommunicatorWrapper(ctx)
    hp.event = StreamEventHandler(int(hp.comm.ctx.device_id))

    if hp.comm.n_procs > 1 and hp.comm.rank == 0:
        n_procs = hp.comm.n_procs
        logger.info(f'Distributed training with {n_procs} processes.')

    rng = np.random.RandomState(hp.seed)

    # setup optimizer
    lr_scheduler = NoamScheduler(hp.alpha, warmup=hp.warmup)
    optimizer = Optimizer(weight_decay=hp.weight_decay,
                          max_norm=hp.max_norm,
                          lr_scheduler=lr_scheduler,
                          name='Adam',
                          alpha=hp.alpha)

    # train data
    train_loader = data_iterator(LJSpeechDataSource('metadata_train.csv',
                                                    hp,
                                                    shuffle=True,
                                                    rng=rng),
                                 batch_size=hp.batch_size,
                                 with_memory_cache=False)
    # valid data
    valid_loader = data_iterator(LJSpeechDataSource('metadata_valid.csv',
                                                    hp,
                                                    shuffle=False,
                                                    rng=rng),
                                 batch_size=hp.batch_size,
                                 with_memory_cache=False)
    dataloader = dict(train=train_loader, valid=valid_loader)
    model = Tacotron(hp)

    TacotronTrainer(model, dataloader, optimizer, hp).run()
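
# For reference, a minimal sketch of the Noam learning-rate schedule that the
# NoamScheduler wrappers in these examples implement (assuming the standard
# formulation from "Attention Is All You Need"; the project classes may differ
# in constructor signature and details):
def noam_lr(step, warmup, dim=256, init_scale=1.0):
    """Linear warmup for `warmup` steps, then decay proportional to 1/sqrt(step)."""
    step = max(step, 1)
    return init_scale * dim ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)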
Example #2
    def __init__(self,
                 hparams,
                 adam_lr=0.002,
                 warmup_epochs=30,
                 init_scale=0.25,
                 checkpoint=None,
                 device='cuda'):
        self.hparams = hparams
        model = DurationExtractor(hparams.duration)
        dataset_root = osp.join(hparams.data.datasets_path,
                                hparams.data.dataset_dir)
        dataset = SpeechDataset(['mels', 'mlens', 'texts', 'tlens'],
                                dataset_root, hparams.text)
        compute_metrics = self.recon_losses
        optimizer = torch.optim.Adam(model.parameters(), lr=adam_lr)
        scheduler = NoamScheduler(optimizer, warmup_epochs, init_scale)
        optimizers = (optimizer, scheduler)

        super(DurationTrainer, self).__init__(model=model,
                                              dataset=dataset,
                                              compute_metrics=compute_metrics,
                                              optimizers=optimizers,
                                              checkpoint=checkpoint,
                                              device=device)
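
# A minimal sketch of how a trainer might consume the (optimizer, scheduler)
# pair built above (an assumption about the base Trainer class, not its code):
def step_optimizers(loss, optimizers):
    optimizer, scheduler = optimizers
    optimizer.zero_grad()
    loss.backward()   # backpropagate the combined loss
    optimizer.step()  # update the weights
    scheduler.step()  # advance the Noam learning-rate schedule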
Example #3
class DurationExtractor(nn.Module):
    """The teacher model for duration extraction"""
    def __init__(self,
                 adam_lr=0.002,
                 warmup_epochs=30,
                 init_scale=0.25,
                 guided_att_sigma=0.3,
                 device='cuda'):
        super(DurationExtractor, self).__init__()

        self.txt_encoder = ConvTextEncoder()
        self.audio_encoder = ConvAudioEncoder()
        self.audio_decoder = ConvAudioDecoder()
        self.attention = ScaledDotAttention()
        self.collate = Collate(device=device)

        # optim
        self.optimizer = torch.optim.Adam(self.parameters(), lr=adam_lr)
        self.scheduler = NoamScheduler(self.optimizer, warmup_epochs,
                                       init_scale)

        # losses
        self.loss_l1 = l1_masked
        self.loss_att = GuidedAttentionLoss(guided_att_sigma)

        # device
        self.device = device
        self.to(self.device)
        print(f'Model sent to {self.device}')

        # helper vars
        self.checkpoint = None
        self.epoch = 0
        self.step = 0

        #repo = git.Repo(search_parent_directories=True)
        #self.git_commit = repo.head.object.hexsha
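
    # For reference, the guided-attention penalty of Tachibana et al. (2017),
    # which GuidedAttentionLoss(guided_att_sigma) presumably implements (a
    # sketch under that assumption, not necessarily the repo's exact code):
    #
    #     W[n, t] = 1 - exp(-((n / N - t / T) ** 2) / (2 * sigma ** 2))
    #
    # where N, T are the phoneme and spectrogram lengths; the loss is the
    # masked mean of W * attention_weights, encouraging diagonal alignments.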

    def to_device(self, device):
        print(f'Sending network to {device}')
        self.device = device
        self.to(device)
        return self

    def save(self):
        """Save a training checkpoint, removing the previous one."""
        if self.checkpoint is not None:
            os.remove(self.checkpoint)
        self.checkpoint = os.path.join(
            self.logger.log_dir,
            f'{time.strftime("%Y-%m-%d")}_checkpoint_step{self.step}.pth')
        torch.save(
            {
                'epoch': self.epoch,
                'step': self.step,
                'state_dict': self.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
                #'git_commit': self.git_commit
            },
            self.checkpoint)
        print(f'Finished saving checkpoint to {self.checkpoint}')

    def load(self, checkpoint):
        """Restore model, optimizer, and scheduler state from a checkpoint file."""
        checkpoint = torch.load(checkpoint)
        self.epoch = checkpoint['epoch']
        self.step = checkpoint['step']
        self.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])

        #commit = checkpoint['git_commit']
        #if commit != self.git_commit:
        #    print(f'Warning: the loaded checkpoint was trained on commit {commit}, but you are on {self.git_commit}')
        self.checkpoint = None  # prevent overriding old checkpoint
        return self

    def forward(self, phonemes, spectrograms, len_phonemes, training=False):
        """
        :param phonemes: (batch, alphabet, time), padded phonemes
        :param spectrograms: (batch, freq, time), padded spectrograms
        :param len_phonemes: list of phoneme lengths
        :param training: unused in this implementation
        :return: decoded_spectrograms, attention_weights
        """
        spectrs = ZeroPad2d(
            (0, 0, 1, 0))(spectrograms)[:, :-1, :]  # move this to encoder?
        keys, values = self.txt_encoder(phonemes)
        queries = self.audio_encoder(spectrs)

        att_mask = mask(shape=(len(keys), queries.shape[1], keys.shape[1]),
                        lengths=len_phonemes,
                        dim=-1).to(self.device)

        if hp.positional_encoding:
            keys += positional_encoding(keys.shape[-1], keys.shape[1],
                                        w=hp.w).to(self.device)
            queries += positional_encoding(queries.shape[-1],
                                           queries.shape[1],
                                           w=1).to(self.device)

        attention, weights = self.attention(queries,
                                            keys,
                                            values,
                                            mask=att_mask)
        decoded = self.audio_decoder(attention + queries)
        return decoded, weights
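
    # A plausible sketch of the project-local positional_encoding helper used
    # above (an assumption, not the repo's code): standard sinusoidal encoding
    # with the rate scaled by w, returning a (length, channels) tensor that
    # broadcasts over the batch dimension.
    #
    #     def positional_encoding(channels, length, w=1.0):
    #         pos = torch.arange(length, dtype=torch.float).unsqueeze(1)
    #         div = torch.exp(torch.arange(0, channels, 2, dtype=torch.float)
    #                         * (-math.log(10000.0) / channels))
    #         pe = torch.zeros(length, channels)
    #         pe[:, 0::2] = torch.sin(w * pos * div)
    #         pe[:, 1::2] = torch.cos(w * pos * div)
    #         return pe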

    def generating(self, mode):
        """Put the module into mode for sequential generation"""
        for module in self.children():
            if hasattr(module, 'generating'):
                module.generating(mode)

    def generate(self,
                 phonemes,
                 len_phonemes,
                 steps=False,
                 window=3,
                 spectrograms=None):
        """Sequentially generate spectrogram from phonemes

        If spectrograms are provided, they are used as input instead of self-generated frames (teacher forcing).
        If steps is provided together with spectrograms, only `steps` frames are generated in supervised fashion.
        Uses layer-level caching for faster inference.

        :param phonemes: Padded phoneme indices
        :param len_phonemes: Length of each sentence in `phonemes` (list of lengths)
        :param steps: How many steps to generate
        :param window: Window size for attention masking
        :param spectrograms: Padded spectrograms
        :return: Generated spectrograms
        """
        self.generating(True)
        self.train(False)

        assert steps or (spectrograms is not None), 'provide steps or spectrograms'
        steps = steps if steps else spectrograms.shape[1]

        with torch.no_grad():
            phonemes = torch.as_tensor(phonemes)
            keys, values = self.txt_encoder(phonemes)

            if hp.positional_encoding:
                keys += positional_encoding(keys.shape[-1],
                                            keys.shape[1],
                                            w=hp.w).to(self.device)
                pe = positional_encoding(hp.channels, steps,
                                         w=1).to(self.device)

            if spectrograms is None:
                dec = torch.zeros(len(phonemes),
                                  1,
                                  hp.out_channels,
                                  device=self.device)
            else:
                dec_input = ZeroPad2d((0, 0, 1, 0))(spectrograms)[:, :-1, :]  # shift right one frame for teacher forcing

            weights, decoded = None, None

            if window is not None:
                shape = (len(phonemes), 1, phonemes.shape[-1])
                idx = torch.zeros(len(phonemes), 1,
                                  phonemes.shape[-1]).to(phonemes.device)
                att_mask = idx_mask(shape, idx, window)
            else:
                att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                                lengths=len_phonemes,
                                dim=-1).to(self.device)

            for i in range(steps):
                if spectrograms is None:
                    queries = self.audio_encoder(dec)
                else:
                    queries = self.audio_encoder(dec_input[:, i:i + 1, :])

                if hp.positional_encoding:
                    queries += pe[i]

                att, w = self.attention(queries, keys, values, att_mask)
                dec = self.audio_decoder(att + queries)
                weights = w if weights is None else torch.cat(
                    (weights, w), dim=1)
                decoded = dec if decoded is None else torch.cat(
                    (decoded, dec), dim=1)
                if window is not None:
                    idx = torch.argmax(w, dim=-1).unsqueeze(2).float()
                    att_mask = idx_mask(shape, idx, window)

        self.generating(False)
        return decoded, weights
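
    # A plausible sketch of the project-local idx_mask helper used above (an
    # assumption, not the repo's implementation): it limits attention to a
    # band of +/- `window` phoneme positions around the last attended index,
    # enforcing roughly monotonic alignment during sequential generation.
    #
    #     def idx_mask(shape, idx, window):
    #         pos = torch.arange(shape[-1], dtype=torch.float).view(1, 1, -1)
    #         return (pos - idx).abs() <= window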

    def generate_naive(self, phonemes, len_phonemes, steps=1, window=(0, 1)):
        """Naive generation without layer-level caching for testing purposes"""

        self.train(False)

        with torch.no_grad():
            phonemes = torch.as_tensor(phonemes)

            keys, values = self.txt_encoder(phonemes)

            if hp.positional_encoding:
                keys += positional_encoding(keys.shape[-1],
                                            keys.shape[1],
                                            w=hp.w).to(self.device)
                pe = positional_encoding(hp.channels, steps,
                                         w=1).to(self.device)

            dec = torch.zeros(len(phonemes),
                              1,
                              hp.out_channels,
                              device=self.device)

            weights = None

            att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                            lengths=len_phonemes,
                            dim=-1).to(self.device)

            for i in range(steps):
                print(i)
                queries = self.audio_encoder(dec)
                if hp.positional_encoding:
                    queries += pe[i]

                att, w = self.attention(queries, keys, values, att_mask)
                d = self.audio_decoder(att + queries)
                d = d[:, -1:]
                w = w[:, -1:]
                weights = w if weights is None else torch.cat(
                    (weights, w), dim=1)
                dec = torch.cat((dec, d), dim=1)

                if window is not None:
                    att_mask = median_mask(weights, window=window)

        return dec[:, 1:, :], weights

    def fit(self,
            batch_size,
            logdir,
            epochs=1,
            grad_clip=1,
            checkpoint_every=10):
        self.grad_clip = grad_clip
        self.logger = SummaryWriter(logdir)

        train_loader = self.train_dataloader(batch_size)
        valid_loader = self.val_dataloader(batch_size)

        # continue training from self.epoch if checkpoint loaded
        for e in range(self.epoch + 1, self.epoch + 1 + epochs):
            self.epoch = e
            train_losses = self._train_epoch(train_loader)
            valid_losses = self._validate(valid_loader)

            self.scheduler.step()
            self.logger.add_scalar('train/learning_rate',
                                   self.optimizer.param_groups[0]['lr'],
                                   self.epoch)
            if not e % checkpoint_every:
                self.save()

            print(
                f'Epoch {e} | Train - l1: {train_losses[0]}, guided_att: {train_losses[1]} | '
                f'Valid - l1: {valid_losses[0]}, guided_att: {valid_losses[1]}'
            )

    def _train_epoch(self, dataloader):
        self.train()

        t_l1, t_att = 0, 0
        for i, batch in enumerate(Bar(dataloader)):
            self.optimizer.zero_grad()
            spectrs, slen, phonemes, plen, text = batch

            s = add_random_noise(spectrs, hp.noise)
            s = degrade_some(self,
                             s,
                             phonemes,
                             plen,
                             hp.feed_ratio,
                             repeat=hp.feed_repeat)
            s = frame_dropout(s, hp.replace_ratio)

            out, att_weights = self.forward(phonemes, s, plen)

            l1 = self.loss_l1(out, spectrs, slen)
            l_att = self.loss_att(att_weights, slen, plen)

            loss = l1 + l_att
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
            self.optimizer.step()
            self.step += 1

            t_l1 += l1.item()
            t_att += l_att.item()

            self.logger.add_scalar('batch/total', loss.item(), self.step)

        # report average loss per batch; i ends at the last batch index, so there are i + 1 batches
        n_batches = i + 1
        self.logger.add_scalar('train/l1', t_l1 / n_batches, self.epoch)
        self.logger.add_scalar('train/guided_att', t_att / n_batches, self.epoch)
        return t_l1 / n_batches, t_att / n_batches

    def _validate(self, dataloader):
        self.eval()

        t_l1, t_att = 0, 0
        for i, batch in enumerate(dataloader):
            spectrs, slen, phonemes, plen, text = batch
            # generate sequentially
            out, att_weights = self.generate(phonemes,
                                             plen,
                                             steps=spectrs.shape[1],
                                             window=None)

            # generate in supervised fashion - for visualisation only
            with torch.no_grad():
                out_s, att_s = self.forward(phonemes, spectrs, plen)

            l1 = self.loss_l1(out, spectrs, slen)
            l_att = self.loss_att(att_weights, slen, plen)
            t_l1 += l1.item()
            t_att += l_att.item()

            fig = display_spectr_alignment(
                out[-1, :slen[-1]], att_weights[-1][:slen[-1], :plen[-1]],
                out_s[-1, :slen[-1]], att_s[-1][:slen[-1], :plen[-1]],
                text[-1])
            self.logger.add_figure(text[-1], fig, self.epoch)
            '''
            if not self.epoch % 1:
                spec = self.collate.norm.inverse(out[-1:]) # TODO: this fails if we do not standardize!
                sound, length = self.collate.stft.spec2wav(spec.transpose(1, 2), slen[-1:])
                sound = sound[0, :length[0]]
                self.logger.add_audio(text[-1], sound.detach().cpu().numpy(), self.epoch, sample_rate=22050) # TODO: parameterize
            '''
        # report average loss per batch; i ends at the last batch index, so there are i + 1 batches
        n_batches = i + 1
        self.logger.add_scalar('valid/l1', t_l1 / n_batches, self.epoch)
        self.logger.add_scalar('valid/guided_att', t_att / n_batches, self.epoch)
        return t_l1 / n_batches, t_att / n_batches

    '''
    def train_dataloader(self, batch_size):
        return DataLoader(AudioDataset(HPText.dataset, start_idx=0, end_idx=HPText.num_train, durations=False), batch_size=batch_size,
                          collate_fn=self.collate,
                          shuffle=False)  # originally True; changed for testing
                          
    def val_dataloader(self, batch_size):
        dataset = AudioDataset(HPText.dataset, start_idx=HPText.num_train, end_idx=HPText.num_valid, durations=False)
        return DataLoader(dataset, batch_size=batch_size,
                          collate_fn=self.collate,
                          shuffle=False, sampler=SequentialSampler(dataset))
    '''

    def train_dataloader(self, batch_size):
        return DataLoader(K_AudioDataset(HPText.k_dataset,
                                         start_idx=0,
                                         end_idx=HPText.k_num_train,
                                         durations=False),
                          batch_size=batch_size,
                          collate_fn=self.collate,
                          shuffle=True)  # originally True; was toggled for testing

    def val_dataloader(self, batch_size):
        dataset = K_AudioDataset(HPText.k_dataset,
                                 start_idx=HPText.k_num_train,
                                 end_idx=HPText.k_num_valid,
                                 durations=False)
        return DataLoader(dataset,
                          batch_size=batch_size,
                          collate_fn=self.collate,
                          shuffle=False,
                          sampler=SequentialSampler(dataset))
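
# Hypothetical end-to-end usage of the teacher model above (the log directory,
# batch size, and epoch count are assumptions, not from the source repo):
#
#     model = DurationExtractor(adam_lr=0.002, warmup_epochs=30, device='cuda')
#     model.fit(batch_size=64, logdir='logs/duration', epochs=300,
#               grad_clip=1, checkpoint_every=10)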