Example #1
    def console_log(self, tag: str, meta: Dict[str, Any], step: int):
        # console logging
        msg = '{}\t{:06d} it'.format(tag, step)
        for key, (value, log_type) in sorted(meta.items()):
            if log_type == LogType.SCALAR:
                msg += '\t{}: {:.6f}'.format(key, value)
        log(msg)
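The meta argument maps each logging key to a (value, LogType) pair; only scalar entries reach the console. Below is a minimal sketch of that structure, assuming a simple LogType enum and a print-based log helper (both are assumptions, not the project's actual definitions):

from enum import Enum
from typing import Any, Dict, Tuple


class LogType(Enum):
    SCALAR = 'scalar'
    IMAGE = 'image'
    AUDIO = 'audio'


def log(msg: str) -> None:
    print(msg)


# example meta dict as consumed by console_log()
meta: Dict[str, Tuple[Any, LogType]] = {
    'loss': (0.1234, LogType.SCALAR),
    'mel_spec': (None, LogType.IMAGE),  # non-scalar entries are skipped
}
# trainer.console_log('train', meta, step=100)
# -> 'train\t000100 it\tloss: 0.123400'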
Example #2
    def load(self, load_optim: bool = True):
        # make name
        save_name = self.save_name

        # save path
        save_path = os.path.join(self.model_dir, save_name)

        # get latest file
        check_files = glob.glob(os.path.join(save_path, '*'))
        if check_files:
            # load latest state dict
            latest_file = max(check_files, key=os.path.getctime)
            state_dict = torch.load(latest_file)
            if 'seed' in state_dict:
                self.seed = state_dict['seed']
            # load model
            if isinstance(self.model, nn.DataParallel):
                self.model.module.load_state_dict(
                    get_loadable_checkpoint(state_dict['model']))
            else:
                self.model.load_state_dict(
                    get_loadable_checkpoint(state_dict['model']))
            if load_optim:
                self.optimizer.load_state_dict(state_dict['optim'])
            if self.scheduler is not None and 'scheduler' in state_dict:
                self.scheduler.load_state_dict(state_dict['scheduler'])
            self.step = state_dict['step']
            log('checkpoint \'{}\' loaded (previous step={})'.format(
                latest_file, self.step))
        else:
            log('No checkpoint found in {}. Skipping network loading.'.format(
                save_path))
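Both load() and save() pass state dicts through get_loadable_checkpoint, which is not shown in these examples. A common implementation of such a helper (purely an assumption about what the project's version does) strips the 'module.' prefix that nn.DataParallel adds to parameter names, so the same checkpoint loads into wrapped and unwrapped models:

from collections import OrderedDict


def get_loadable_checkpoint(state_dict):
    # hypothetical sketch: drop the 'module.' prefix that nn.DataParallel
    # prepends to parameter names so the checkpoint loads either way
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len('module.'):] if key.startswith('module.') else key
        new_state_dict[new_key] = value
    return new_state_dict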
Example #3
    def save(self, step: int):

        # state dict
        state_dict = get_loadable_checkpoint(self.model.state_dict())

        # train
        state_dict = {
            'step': step,
            'model': state_dict,
            'optim': self.optimizer.state_dict(),
            'pretrained_step': step,
            'seed': self.seed
        }
        if self.scheduler is not None:
            state_dict.update({'scheduler': self.scheduler.state_dict()})

        # save for training
        save_name = self.save_name

        save_path = os.path.join(self.model_dir, save_name)
        os.makedirs(save_path, exist_ok=True)
        torch.save(state_dict,
                   os.path.join(save_path, 'step_{:06d}.chkpt'.format(step)))

        # save best
        if self.best_valid_loss != self.cur_best_valid_loss:
            save_path = os.path.join(self.model_dir, save_name + '.best.chkpt')
            torch.save(state_dict, save_path)
            self.cur_best_valid_loss = self.best_valid_loss

        # logging
        log('step %d / saved model.' % step)
Example #4
    def train(self, step: int) -> torch.Tensor:

        # update model
        self.optimizer.zero_grad()

        # flag for logging
        log_flag = step % self.log_interval == 0

        # forward model
        loss, meta = self.forward(*to_device(next(self.train_dataset)),
                                  log_flag)

        # check for NaN loss (NaN is the only value not equal to itself)
        if loss != loss:
            log('NaN loss at step {}; skipping update'.format(step))
            return

        loss.backward()
        self.clip_grad()
        self.optimizer.step()

        # logging
        if log_flag:
            # console logging
            self.console_log('train', meta, step)
            # tensorboard logging
            self.tensorboard_log('train', meta, step)

        return loss
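The loss != loss test works because NaN is the only value not equal to itself; torch.isnan(loss) is an equivalent, more explicit check. train() also calls self.clip_grad(), which is not shown in these examples; a plausible implementation (an assumption based on the grad_clip and grad_norm arguments in Example #8) uses PyTorch's built-in clipping utilities:

    def clip_grad(self):
        # hypothetical sketch: clip per-parameter gradient values and/or the
        # global gradient norm, depending on which limits were configured
        if self.grad_clip > 0.0:
            torch.nn.utils.clip_grad_value_(self.model.parameters(),
                                            self.grad_clip)
        if self.grad_norm > 0.0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.grad_norm)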
Example #5
    def validate(self, step: int):

        loss = 0.
        stat = defaultdict(float)

        for i in range(self.valid_max_step):
            # flag for logging
            log_flag = i % self.log_interval == 0 or i == self.valid_max_step - 1

            # forward model
            with torch.no_grad():
                batch_loss, meta = self.forward(*to_device(next(self.valid_dataset)), is_logging=log_flag)
                loss += batch_loss

            # update stat
            for key, (value, log_type) in meta.items():
                if log_type == LogType.SCALAR:
                    stat[key] += value

            # console logging of this step
            if (i + 1) % self.log_interval == 0:
                self.console_log('valid', meta, i + 1)

        # non-scalar meta (e.g. images / audio) from the last validation batch
        meta_non_scalar = {
            key: (value, log_type) for key, (value, log_type) in meta.items()
            if log_type != LogType.SCALAR
        }

        try:
            self.tensorboard_log('valid', meta_non_scalar, step)
        except OverflowError:
            pass

        # averaging stat
        loss /= self.valid_max_step
        for key in stat.keys():
            stat[key] = stat[key] / self.valid_max_step

        # update best valid loss
        if loss < self.best_valid_loss:
            self.best_valid_loss = loss

        # console logging of total stat
        msg = 'step {} / total stat'.format(step)
        for key, value in sorted(stat.items()):
            msg += '\t{}: {:.6f}'.format(key, value)
        log(msg)

        # tensor board logging of scalar stat
        for key, value in stat.items():
            self.writer.add_scalar('valid/{}'.format(key), value, global_step=step)
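tensorboard_log() is referenced throughout but not shown. A plausible dispatcher (an assumption; the IMAGE and AUDIO members follow the hypothetical LogType sketched after Example #1) would route each meta entry to the matching SummaryWriter method:

    def tensorboard_log(self, tag: str, meta: Dict[str, Any], step: int):
        # hypothetical sketch: route each entry to the writer by its LogType
        for key, (value, log_type) in meta.items():
            name = '{}/{}'.format(tag, key)
            if log_type == LogType.SCALAR:
                self.writer.add_scalar(name, value, global_step=step)
            elif log_type == LogType.IMAGE:
                self.writer.add_image(name, value, global_step=step)
            elif log_type == LogType.AUDIO:
                self.writer.add_audio(name, value, sample_rate=self.sr,
                                      global_step=step)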
Example #6
    def validate(self, step: int):

        loss = 0.
        stat = defaultdict(float)

        for i in range(self.valid_max_step):
            # forward model
            with torch.no_grad():
                batch_loss, meta = self.forward(*to_device(
                    next(self.valid_dataset)),
                                                is_logging=True)
                loss += batch_loss

            for key, (value, log_type) in meta.items():
                if log_type == LogType.SCALAR:
                    stat[key] += value

            if i % self.log_interval == 0 or i == self.valid_max_step - 1:
                self.console_log('valid', meta, i + 1)

        # averaging stat
        loss /= self.valid_max_step
        for key in stat.keys():
            if key == 'loss':
                continue
            stat[key] = stat[key] / self.valid_max_step
        stat['loss'] = loss

        # update best valid loss
        if loss < self.best_valid_loss:
            self.best_valid_loss = loss

        # console logging of total stat
        msg = 'step {} / total stat'.format(step)
        for key, value in sorted(stat.items()):
            msg += '\t{}: {:.6f}'.format(key, value)
        log(msg)

        # tensor board logging of scalar stat
        for key, value in stat.items():
            self.writer.add_scalar('valid/{}'.format(key),
                                   value,
                                   global_step=step)
Example #7
    def run(self) -> float:
        try:
            # training loop
            for i in range(self.step + 1, self.max_step + 1):

                # update step
                self.step = i

                # logging
                if i % self.save_interval == 1:
                    log('------------- TRAIN step : %d -------------' % i)

                # do training step
                if self.scheduler is not None:
                    self.scheduler.step(i)
                self.model.train()
                self.train(i)

                # save model
                if i % self.save_interval == 0:
                    log('------------- VALID step : %d -------------' % i)
                    # valid
                    self.model.eval()
                    self.validate(i)
                    # save model checkpoint file
                    self.save(i)

        except KeyboardInterrupt:
            log('Training was cancelled!')

        return self.best_valid_loss
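run() drives the whole loop; a concrete trainer only needs to provide forward(), which Examples #4-#6 call with the unpacked batch plus a logging flag. A hypothetical subclass and usage follow (the base-class name Trainer, the (inputs, targets) batch layout, and LogType are assumptions):

import torch
import torch.nn.functional as F


class MyTrainer(Trainer):

    def forward(self, inputs, targets, is_logging: bool):
        # return (loss, meta) as expected by train() / validate()
        preds = self.model(inputs)
        loss = F.mse_loss(preds, targets)
        meta = {'loss': (loss.item(), LogType.SCALAR)} if is_logging else {}
        return loss, meta


# trainer = MyTrainer(model, optimizer, train_dataset, valid_dataset,
#                     max_step=100000, valid_max_step=100,
#                     save_interval=10000, log_interval=100,
#                     save_dir='outputs', save_prefix='baseline')
# best_valid_loss = trainer.run()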
Example #8
    def __init__(self,
                 model: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 train_dataset,
                 valid_dataset,
                 max_step: int,
                 valid_max_step: int,
                 save_interval: int,
                 log_interval: int,
                 save_dir: str,
                 save_prefix: str = 'save',
                 grad_clip: float = 0.0,
                 grad_norm: float = 0.0,
                 pretrained_path: str = None,
                 sr: int = None,
                 scheduler: torch.optim.lr_scheduler._LRScheduler = None):

        # save project info
        self.pretrained_trained = pretrained_path

        # model
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler

        # log how many parameters in the model
        n_params = sum(p.numel() for p in self.model.parameters()
                       if p.requires_grad)
        log('Model {} was loaded. Total {} trainable params.'.format(
            self.model.__class__.__name__, n_params))

        # adopt repeating function on datasets
        self.train_dataset = self.repeat(train_dataset)
        self.valid_dataset = self.repeat(valid_dataset)

        # save parameters
        self.step = 0
        if sr:
            self.sr = sr
        else:
            self.sr = SAMPLE_RATE
        self.max_step = max_step
        self.save_interval = save_interval
        self.log_interval = log_interval
        self.save_dir = save_dir
        self.save_prefix = save_prefix
        self.grad_clip = grad_clip
        self.grad_norm = grad_norm
        self.valid_max_step = valid_max_step

        # make dirs
        self.log_dir = os.path.join(save_dir, 'logs', self.save_prefix)
        self.model_dir = os.path.join(save_dir, 'models')
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.model_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)

        # load previous checkpoint (may also restore the saved seed)
        self.seed = None
        self.load()

        if self.seed is None:
            self.seed = np.random.randint(np.iinfo(np.int32).max)
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)
            torch.cuda.manual_seed(self.seed)

        # load pretrained model
        if self.step == 0 and pretrained_path:
            self.load_pretrained_model()

        # valid loss
        self.best_valid_loss = np.finfo(np.float32).max
        self.cur_best_valid_loss = self.best_valid_loss
        self.save_valid_loss = np.finfo(np.float32).max
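__init__ wraps both datasets with self.repeat(), which is not shown in these examples. A common pattern for such a helper (an assumption, not the project's code) is an endless generator over the loader, so next() can be called once per training step without StopIteration; the same pattern would also cover the module-level repeat() used in Example #9:

def repeat(iterable):
    # hypothetical sketch: cycle through a finite DataLoader forever so that
    # next() can be drawn once per training step without StopIteration
    while True:
        for batch in iterable:
            yield batch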
Example #9
def main(meta_dir: str,
         save_dir: str,
         save_prefix: str,
         pretrained_path: str = '',
         batch_size: int = 32,
         num_workers: int = 8,
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.5, 0.9),
         weight_decay: float = 0.0,
         pretrain_step: int = 200000,
         max_step: int = 1000000,
         save_interval: int = 10000,
         log_scala_interval: int = 20,
         log_heavy_interval: int = 1000,
         gamma: float = 0.5,
         seed: int = 1234):
    #
    # prepare training
    #
    # create model
    mb_generator = build_model('generator_mb').cuda()
    discriminator = build_model('discriminator_base').cuda()

    # Multi-gpu is not required.

    # create optimizers
    mb_opt = torch.optim.Adam(mb_generator.parameters(),
                              lr=lr,
                              betas=betas,
                              weight_decay=weight_decay)
    dis_opt = torch.optim.Adam(discriminator.parameters(),
                               lr=lr,
                               betas=betas,
                               weight_decay=weight_decay)

    # make scheduler
    mb_scheduler = MultiStepLR(mb_opt,
                               list(range(300000, 900000 + 1, 100000)),
                               gamma=gamma)
    dis_scheduler = MultiStepLR(dis_opt,
                                list(range(100000, 700000 + 1, 100000)),
                                gamma=gamma)

    # get datasets
    train_loader, valid_loader = get_datasets(meta_dir,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              crop_length=settings.SAMPLE_RATE,
                                              random_seed=seed)

    # repeat
    train_loader = repeat(train_loader)

    # build mel function
    mel_func, stft_funcs_for_loss = build_stft_functions()

    # build pqmf
    pqmf_func = PQMF().cuda()

    # prepare logging
    writer, model_dir = prepare_logging(save_dir, save_prefix)

    # Training Saving Attributes
    best_loss = np.finfo(np.float32).max
    initial_step = 0

    # load model
    if pretrained_path:
        log(f'Pretrained path is given : {pretrained_path} . Loading...')
        chk = torch.load(pretrained_path)
        gen_chk, dis_chk = chk['generator'], chk['discriminator']
        gen_opt_chk, dis_opt_chk = chk['gen_opt'], chk['dis_opt']
        initial_step = int(chk['step'])
        loaded_loss = chk['loss']

        mb_generator.load_state_dict(gen_chk)
        discriminator.load_state_dict(dis_chk)
        mb_opt.load_state_dict(gen_opt_chk)
        dis_opt.load_state_dict(dis_opt_chk)
        if 'dis_scheduler' in chk:
            dis_scheduler_chk = chk['dis_scheduler']
            gen_scheduler_chk = chk['gen_scheduler']
            mb_scheduler.load_state_dict(gen_scheduler_chk)
            dis_scheduler.load_state_dict(dis_scheduler_chk)

        mb_opt._step_count = initial_step
        mb_scheduler._step_count = initial_step
        dis_opt._step_count = initial_step - pretrain_step
        dis_scheduler._step_count = initial_step - pretrain_step

        mb_scheduler.step(initial_step)
        dis_scheduler.step(initial_step - pretrain_step)
        best_loss = loaded_loss

    #
    # Training !
    #
    # Pretraining generator
    for step in range(initial_step, pretrain_step):
        # data
        wav, _ = next(train_loader)
        wav = wav.cuda()

        # to mel
        mel = mel_func(wav)

        # pqmf
        target_subbands = pqmf_func.analysis(wav.unsqueeze(1))  # N, SUBBAND, T

        # forward
        pred_subbands = mb_generator(mel)
        pred_subbands, _ = match_dim(pred_subbands, target_subbands)

        # pqmf synthesis
        pred = pqmf_func.synthesis(pred_subbands)
        pred, wav = match_dim(pred, wav)

        # get multi-resolution stft loss   eq 9)
        loss, mb_loss, fb_loss = get_stft_loss(pred, wav, pred_subbands,
                                               target_subbands,
                                               stft_funcs_for_loss)

        # backward and update
        loss.backward()
        mb_opt.step()
        mb_scheduler.step()

        mb_opt.zero_grad()
        mb_generator.zero_grad()

        #
        # logging! save!
        #
        if step % log_scala_interval == 0 and step > 0:
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('train/pretrain_loss',
                              loss.item(),
                              global_step=step)
            writer.add_scalar('train/mb_loss',
                              mb_loss.item(),
                              global_step=step)
            writer.add_scalar('train/fb_loss',
                              fb_loss.item(),
                              global_step=step)

            if step % log_heavy_interval == 0:
                writer.add_audio('train/pred_audio',
                                 pred_audio,
                                 sample_rate=settings.SAMPLE_RATE,
                                 global_step=step)
                writer.add_audio('train/target_audio',
                                 target_audio,
                                 sample_rate=settings.SAMPLE_RATE,
                                 global_step=step)

            # console
            msg = f'train: step: {step} / loss: {loss.item()} / mb_loss: {mb_loss.item()} / fb_loss: {fb_loss.item()}'
            log(msg)

        if step % save_interval == 0 and step > 0:
            #
            # Validation Step !
            #
            valid_loss = 0.
            valid_mb_loss, valid_fb_loss = 0., 0.
            count = 0
            mb_generator.eval()

            for idx, (wav, _) in enumerate(valid_loader):
                # setup data
                wav = wav.cuda()
                mel = mel_func(wav)

                with torch.no_grad():
                    # pqmf
                    target_subbands = pqmf_func.analysis(
                        wav.unsqueeze(1))  # N, SUBBAND, T

                    # forward
                    pred_subbands = mb_generator(mel)
                    pred_subbands, _ = match_dim(pred_subbands,
                                                 target_subbands)

                    # pqmf synthesis
                    pred = pqmf_func.synthesis(pred_subbands)
                    pred, wav = match_dim(pred, wav)

                    # get stft loss
                    loss, mb_loss, fb_loss = get_stft_loss(
                        pred, wav, pred_subbands, target_subbands,
                        stft_funcs_for_loss)

                valid_loss += loss.item()
                valid_mb_loss += mb_loss.item()
                valid_fb_loss += fb_loss.item()
                count = idx

            valid_loss /= (count + 1)
            valid_mb_loss /= (count + 1)
            valid_fb_loss /= (count + 1)
            mb_generator.train()

            # log validation
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('valid/pretrain_loss',
                              valid_loss,
                              global_step=step)
            writer.add_scalar('valid/mb_loss', valid_mb_loss, global_step=step)
            writer.add_scalar('valid/fb_loss', valid_fb_loss, global_step=step)
            writer.add_audio('valid/pred_audio',
                             pred_audio,
                             sample_rate=settings.SAMPLE_RATE,
                             global_step=step)
            writer.add_audio('valid/target_audio',
                             target_audio,
                             sample_rate=settings.SAMPLE_RATE,
                             global_step=step)

            # console
            log(f'---- Valid loss: {valid_loss} / mb_loss: {valid_mb_loss} / fb_loss: {valid_fb_loss} ----')

            #
            # save checkpoint
            #
            is_best = valid_loss < best_loss
            if is_best:
                best_loss = valid_loss
            save_checkpoint(mb_generator,
                            discriminator,
                            mb_opt,
                            dis_opt,
                            mb_scheduler,
                            dis_scheduler,
                            model_dir,
                            step,
                            valid_loss,
                            is_best=is_best)

    #
    # Train GAN
    #
    dis_block_layers = 6
    lambda_gen = 2.5
    best_loss = np.finfo(np.float32).max

    for step in range(max(pretrain_step, initial_step), max_step):

        # data
        wav, _ = next(train_loader)
        wav = wav.cuda()

        # to mel
        mel = mel_func(wav)

        # pqmf
        target_subbands = pqmf_func.analysis(wav.unsqueeze(1))  # N, SUBBAND, T

        #
        # Train Discriminator
        #

        # forward
        pred_subbands = mb_generator(mel)
        pred_subbands, _ = match_dim(pred_subbands, target_subbands)

        # pqmf synthesis
        pred = pqmf_func.synthesis(pred_subbands)
        pred, wav = match_dim(pred, wav)

        with torch.no_grad():
            pred_mel = mel_func(pred.squeeze(1).detach())
            mel_err = F.l1_loss(mel, pred_mel).item()

        # if terminate_step > step:
        d_fake_det = discriminator(pred.detach())
        d_real = discriminator(wav.unsqueeze(1))

        # calculate discriminator losses  eq 1): real outputs are pushed
        # toward 1 and generated (fake) outputs toward 0 (LSGAN objective)
        loss_D = 0

        for idx in range(dis_block_layers - 1, len(d_real), dis_block_layers):
            loss_D += torch.mean((d_real[idx] - 1)**2)

        for idx in range(dis_block_layers - 1, len(d_fake_det),
                         dis_block_layers):
            loss_D += torch.mean(d_fake_det[idx]**2)

        # train
        discriminator.zero_grad()
        loss_D.backward()
        dis_opt.step()
        dis_scheduler.step()

        #
        # Train Generator
        #
        d_fake = discriminator(pred)

        # calc generator loss   eq 8)
        loss_G = 0
        for idx in range(dis_block_layers - 1, len(d_fake), dis_block_layers):
            loss_G += ((d_fake[idx] - 1)**2).mean()

        loss_G *= lambda_gen

        # get multi-resolution stft loss
        loss_G += get_stft_loss(pred, wav, pred_subbands, target_subbands,
                                stft_funcs_for_loss)[0]
        # loss_G += get_spec_losses(pred, wav, stft_funcs_for_loss)[0]

        mb_generator.zero_grad()
        loss_G.backward()
        mb_opt.step()
        mb_scheduler.step()

        #
        # logging! save!
        #
        if step % log_scala_interval == 0 and step > 0:
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('train/loss_G', loss_G.item(), global_step=step)
            writer.add_scalar('train/loss_D', loss_D.item(), global_step=step)
            writer.add_scalar('train/mel_err', mel_err, global_step=step)
            if step % log_heavy_interval == 0:
                target_mel = imshow_to_buf(mel[0].detach().cpu().numpy())
                pred_mel = imshow_to_buf(
                    mel_func(pred[:1, 0])[0].detach().cpu().numpy())

                writer.add_image('train/target_mel',
                                 target_mel,
                                 global_step=step)
                writer.add_image('train/pred_mel', pred_mel, global_step=step)
                writer.add_audio('train/pred_audio',
                                 pred_audio,
                                 sample_rate=settings.SAMPLE_RATE,
                                 global_step=step)
                writer.add_audio('train/target_audio',
                                 target_audio,
                                 sample_rate=settings.SAMPLE_RATE,
                                 global_step=step)

            # console
            msg = f'train: step: {step} / loss_G: {loss_G.item()} / loss_D: {loss_D.item()} / ' \
                f' mel_err: {mel_err}'
            log(msg)

        if step % save_interval == 0 and step > 0:
            #
            # Validation Step !
            #
            valid_g_loss, valid_d_loss, valid_mel_loss = 0., 0., 0.
            count = 0
            mb_generator.eval()
            discriminator.eval()

            for idx, (wav, _) in enumerate(valid_loader):
                # setup data
                wav = wav.cuda()
                mel = mel_func(wav)

                with torch.no_grad():
                    # pqmf
                    target_subbands = pqmf_func.analysis(
                        wav.unsqueeze(1))  # N, SUBBAND, T

                    # Discriminator
                    pred_subbands = mb_generator(mel)
                    pred_subbands, _ = match_dim(pred_subbands,
                                                 target_subbands)

                    # pqmf synthesis
                    pred = pqmf_func.synthesis(pred_subbands)
                    pred, wav = match_dim(pred, wav)

                    # Mel Error
                    pred_mel = mel_func(pred.squeeze(1).detach())
                    mel_err = F.l1_loss(mel, pred_mel).item()

                    #
                    # discriminator part
                    #
                    d_fake_det = discriminator(pred.detach())
                    d_real = discriminator(wav.unsqueeze(1))

                    # real toward 1, fake toward 0; a separate loop variable
                    # avoids shadowing the enumerate() index used for count
                    loss_D = 0

                    for layer_idx in range(dis_block_layers - 1, len(d_real),
                                           dis_block_layers):
                        loss_D += torch.mean((d_real[layer_idx] - 1)**2)

                    for layer_idx in range(dis_block_layers - 1,
                                           len(d_fake_det), dis_block_layers):
                        loss_D += torch.mean(d_fake_det[layer_idx]**2)

                    #
                    # generator part
                    #
                    d_fake = discriminator(pred)

                    # calc generator loss
                    loss_G = 0
                    for layer_idx in range(dis_block_layers - 1, len(d_fake),
                                           dis_block_layers):
                        loss_G += ((d_fake[layer_idx] - 1)**2).mean()

                    loss_G *= lambda_gen

                    # get stft loss
                    stft_loss = get_stft_loss(pred, wav, pred_subbands,
                                              target_subbands,
                                              stft_funcs_for_loss)[0]
                    loss_G += stft_loss

                valid_d_loss += loss_D.item()
                valid_g_loss += loss_G.item()
                valid_mel_loss += mel_err
                count = idx

            valid_d_loss /= (count + 1)
            valid_g_loss /= (count + 1)
            valid_mel_loss /= (count + 1)

            mb_generator.train()
            discriminator.train()

            # log validation
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            target_mel = imshow_to_buf(mel[0].detach().cpu().numpy())
            pred_mel = imshow_to_buf(
                mel_func(pred[:1, 0])[0].detach().cpu().numpy())

            writer.add_image('valid/target_mel', target_mel, global_step=step)
            writer.add_image('valid/pred_mel', pred_mel, global_step=step)
            writer.add_scalar('valid/loss_G', valid_g_loss, global_step=step)
            writer.add_scalar('valid/loss_D', valid_d_loss, global_step=step)
            writer.add_scalar('valid/mel_err',
                              valid_mel_loss,
                              global_step=step)
            writer.add_audio('valid/pred_audio',
                             pred_audio,
                             sample_rate=settings.SAMPLE_RATE,
                             global_step=step)
            writer.add_audio('valid/target_audio',
                             target_audio,
                             sample_rate=settings.SAMPLE_RATE,
                             global_step=step)

            # console
            log(f'---- loss_G: {valid_g_loss} / loss_D: {valid_d_loss} / mel loss: {valid_mel_loss} ----')

            #
            # save checkpoint
            #
            is_best = valid_g_loss < best_loss
            if is_best:
                best_loss = valid_g_loss
            save_checkpoint(mb_generator,
                            discriminator,
                            mb_opt,
                            dis_opt,
                            mb_scheduler,
                            dis_scheduler,
                            model_dir,
                            step,
                            valid_g_loss,
                            is_best=is_best)

    log('----- Finish ! -----')
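save_checkpoint itself is not shown, but the loading branch near the top of main() reads the keys 'generator', 'discriminator', 'gen_opt', 'dis_opt', 'gen_scheduler', 'dis_scheduler', 'step' and 'loss' back out of the file. A minimal sketch consistent with those keys follows (the file naming and the is_best handling are assumptions):

import os

import torch


def save_checkpoint(generator, discriminator, gen_opt, dis_opt,
                    gen_scheduler, dis_scheduler, model_dir, step, loss,
                    is_best: bool = False):
    # bundle everything the loading code in main() expects to find
    chk = {
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
        'gen_opt': gen_opt.state_dict(),
        'dis_opt': dis_opt.state_dict(),
        'gen_scheduler': gen_scheduler.state_dict(),
        'dis_scheduler': dis_scheduler.state_dict(),
        'step': step,
        'loss': loss,
    }
    os.makedirs(model_dir, exist_ok=True)
    torch.save(chk, os.path.join(model_dir, 'step_{:06d}.chkpt'.format(step)))
    if is_best:
        torch.save(chk, os.path.join(model_dir, 'best.chkpt'))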