def console_log(self, tag: str, meta: Dict[str, Any], step: int):
    """Emit one tab-separated console line with every SCALAR entry of *meta*.

    Args:
        tag: prefix identifying the phase (e.g. 'train' / 'valid').
        meta: mapping of name -> (value, LogType); only SCALAR entries are printed.
        step: current global step, zero-padded into the header.
    """
    parts = ['{}\t{:06d} it'.format(tag, step)]
    # deterministic ordering so successive log lines are comparable
    for key, (value, log_type) in sorted(meta.items()):
        if log_type == LogType.SCALAR:
            parts.append('{}: {:.6f}'.format(key, value))
    log('\t'.join(parts))
def load(self, load_optim: bool = True):
    """Restore the newest checkpoint found under ``model_dir/save_name``.

    Restores model weights (always), optimizer/scheduler state (when
    ``load_optim`` is True and a scheduler exists), the saved RNG seed,
    and ``self.step``. If the directory holds no files, loading is
    skipped with a log message.

    :param load_optim: also restore the optimizer state dict.
    """
    # make name
    save_name = self.save_name
    # save path
    save_path = os.path.join(self.model_dir, save_name)
    # get latest file
    check_files = glob.glob(os.path.join(save_path, '*'))
    if check_files:
        # load latest state dict — "latest" is decided by file creation time
        latest_file = max(check_files, key=os.path.getctime)
        # NOTE(review): no map_location here — assumes the checkpoint's
        # original device (e.g. CUDA) is available at load time; confirm.
        state_dict = torch.load(latest_file)
        # restore the RNG seed if the checkpoint recorded one
        if 'seed' in state_dict:
            self.seed = state_dict['seed']
        # load model — unwrap DataParallel so keys match the bare module
        if isinstance(self.model, nn.DataParallel):
            self.model.module.load_state_dict(
                get_loadable_checkpoint(state_dict['model']))
        else:
            self.model.load_state_dict(
                get_loadable_checkpoint(state_dict['model']))
        if load_optim:
            self.optimizer.load_state_dict(state_dict['optim'])
        # scheduler state is only present when training used one
        if self.scheduler is not None:
            self.scheduler.load_state_dict(state_dict['scheduler'])
        # resume the global step counter from the checkpoint
        self.step = state_dict['step']
        log('checkpoint \'{}\' is loaded. previous step={}'.format(
            latest_file, self.step))
    else:
        # nothing to resume from — start fresh
        log('No any checkpoint in {}. Loading network skipped.'.format(
            save_path))
def save(self, step: int):
    """Persist a training checkpoint for *step*.

    Writes ``model_dir/save_name/step_XXXXXX.chkpt`` containing the
    model/optimizer (and optional scheduler) state, the step, and the
    seed. When the best validation loss improved since the last save,
    the same payload is mirrored to ``model_dir/save_name.best.chkpt``.
    """
    # strip any wrapper prefixes so the weights are loadable anywhere
    model_weights = get_loadable_checkpoint(self.model.state_dict())
    checkpoint = {
        'step': step,
        'model': model_weights,
        'optim': self.optimizer.state_dict(),
        'pretrained_step': step,
        'seed': self.seed,
    }
    if self.scheduler is not None:
        checkpoint['scheduler'] = self.scheduler.state_dict()
    # save for training
    ckpt_dir = os.path.join(self.model_dir, self.save_name)
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(checkpoint,
               os.path.join(ckpt_dir, 'step_{:06d}.chkpt'.format(step)))
    # save best: mirror whenever validation improved since the last save
    if self.best_valid_loss != self.cur_best_valid_loss:
        best_path = os.path.join(self.model_dir, self.save_name + '.best.chkpt')
        torch.save(checkpoint, best_path)
        self.cur_best_valid_loss = self.best_valid_loss
    # logging
    log('step %d / saved model.' % step)
def train(self, step: int) -> torch.Tensor:
    """Run one optimization step and return the training loss tensor.

    Fix vs. original: the signature promised ``-> torch.Tensor`` but the
    body never returned the loss (implicit ``None`` on both paths); the
    loss is now returned in every case. The caller (``run``) ignores the
    return value, so this is backward-compatible.

    :param step: current global step (drives the logging cadence).
    :return: the (possibly NaN) loss tensor produced by ``self.forward``.
    """
    # clear gradients from the previous step
    self.optimizer.zero_grad()
    # flag for logging
    log_flag = step % self.log_interval == 0
    # forward model
    loss, meta = self.forward(*to_device(next(self.train_dataset)), log_flag)
    # NaN check: a NaN never equals itself; skip the update to avoid
    # corrupting the weights with NaN gradients
    if loss != loss:
        log('{} cur step NAN is occured'.format(step))
        return loss
    loss.backward()
    self.clip_grad()
    self.optimizer.step()
    # logging
    if log_flag:
        # console logging
        self.console_log('train', meta, step)
        # tensorboard logging
        self.tensorboard_log('train', meta, step)
    return loss
def validate(self, step: int):
    """Run the validation loop and log averaged scalar statistics.

    NOTE(review): this method is re-defined later in the same class, so
    this earlier definition is shadowed (dead code) — confirm which
    version is intended and delete the other.

    :param step: global training step used for tensorboard logging.
    """
    loss = 0.
    # accumulates per-key SCALAR values across all validation batches
    stat = defaultdict(float)
    for i in range(self.valid_max_step):
        # flag for logging (periodic, plus always on the final batch)
        log_flag = i % self.log_interval == 0 or i == self.valid_max_step - 1
        # forward model (no gradients needed during validation)
        with torch.no_grad():
            batch_loss, meta = self.forward(*to_device(next(self.valid_dataset)),
                                            is_logging=log_flag)
            loss += batch_loss
        # update stat
        for key, (value, log_type) in meta.items():
            if log_type == LogType.SCALAR:
                stat[key] += value
        # console logging of this step
        if (i + 1) % self.log_interval == 0:
            self.console_log('valid', meta, i + 1)
            # non-scalar entries (images/audio) go to tensorboard only
            meta_non_scalar = {
                key: (value, log_type)
                for key, (value, log_type) in meta.items()
                if not log_type == LogType.SCALAR
            }
            try:
                self.tensorboard_log('valid', meta_non_scalar, step)
            except OverflowError:
                # tolerated: oversized payloads are simply dropped
                pass
    # averaging stat
    loss /= self.valid_max_step
    for key in stat.keys():
        stat[key] = stat[key] / self.valid_max_step
    # update best valid loss
    if loss < self.best_valid_loss:
        self.best_valid_loss = loss
    # console logging of total stat
    msg = 'step {} / total stat'.format(step)
    for key, value in sorted(stat.items()):
        msg += '\t{}: {:.6f}'.format(key, value)
    log(msg)
    # tensor board logging of scalar stat
    for key, value in stat.items():
        self.writer.add_scalar('valid/{}'.format(key), value, global_step=step)
def validate(self, step: int):
    """Run ``valid_max_step`` validation batches, then log averaged stats.

    Averages the loss and every SCALAR meta entry over the whole
    validation pass, updates ``self.best_valid_loss`` when improved,
    and writes the averages to console and tensorboard.

    Fix vs. original: removed the unused local ``count = 0`` (assigned
    once, never read).

    :param step: global training step used for tensorboard logging.
    """
    loss = 0.
    # accumulates per-key SCALAR values across all validation batches
    stat = defaultdict(float)
    for i in range(self.valid_max_step):
        # forward model (no gradients needed during validation)
        with torch.no_grad():
            batch_loss, meta = self.forward(*to_device(
                next(self.valid_dataset)), is_logging=True)
            loss += batch_loss
        for key, (value, log_type) in meta.items():
            if log_type == LogType.SCALAR:
                stat[key] += value
        # periodic console logging, plus always on the final batch
        if i % self.log_interval == 0 or i == self.valid_max_step - 1:
            self.console_log('valid', meta, i + 1)
    # averaging stat ('loss' is replaced wholesale below, so skip it here)
    loss /= self.valid_max_step
    for key in stat.keys():
        if key == 'loss':
            continue
        stat[key] = stat[key] / self.valid_max_step
    stat['loss'] = loss
    # update best valid loss
    if loss < self.best_valid_loss:
        self.best_valid_loss = loss
    # console logging of total stat
    msg = 'step {} / total stat'.format(step)
    for key, value in sorted(stat.items()):
        msg += '\t{}: {:.6f}'.format(key, value)
    log(msg)
    # tensor board logging of scalar stat
    for key, value in stat.items():
        self.writer.add_scalar('valid/{}'.format(key), value, global_step=step)
def run(self) -> float:
    """Drive the train/validate/save loop until ``max_step`` or Ctrl-C.

    Resumes from ``self.step + 1``. Every ``save_interval`` steps the
    model is validated and checkpointed. A KeyboardInterrupt stops
    training cleanly.

    :return: the best validation loss observed so far.
    """
    try:
        # training loop
        cur_step = self.step + 1
        while cur_step <= self.max_step:
            # update step
            self.step = cur_step
            # banner at the start of each save window
            if cur_step % self.save_interval == 1:
                log('------------- TRAIN step : %d -------------' % cur_step)
            # do training step
            if self.scheduler is not None:
                self.scheduler.step(cur_step)
            self.model.train()
            self.train(cur_step)
            # validate + checkpoint at the end of each save window
            if cur_step % self.save_interval == 0:
                log('------------- VALID step : %d -------------' % cur_step)
                self.model.eval()
                self.validate(cur_step)
                # save model checkpoint file
                self.save(cur_step)
            cur_step += 1
    except KeyboardInterrupt:
        log('Train is canceled !!')
    return self.best_valid_loss
def __init__(self,
             model: nn.Module,
             optimizer: torch.optim.Optimizer,
             train_dataset,
             valid_dataset,
             max_step: int,
             valid_max_step: int,
             save_interval: int,
             log_interval: int,
             save_dir: str,
             save_prefix: str = 'save',
             grad_clip: float = 0.0,
             grad_norm: float = 0.0,
             pretrained_path: str = None,
             sr: int = None,
             scheduler: torch.optim.lr_scheduler._LRScheduler = None):
    """Set up the trainer: model/optimizer, datasets, dirs, seed, resume.

    Order matters below: ``self.load()`` runs before seeding (so a
    checkpointed seed wins), and the pretrained model is only loaded
    when no checkpoint advanced ``self.step`` past 0.

    :param model: network to train.
    :param optimizer: optimizer bound to *model*'s parameters.
    :param train_dataset: iterable of training batches (wrapped by repeat).
    :param valid_dataset: iterable of validation batches (wrapped by repeat).
    :param max_step: total number of training steps.
    :param valid_max_step: batches per validation pass.
    :param save_interval: steps between checkpoint/validation.
    :param log_interval: steps between log lines.
    :param save_dir: root directory for logs and models.
    :param save_prefix: run name used in paths.
    :param grad_clip: gradient value clip (0.0 disables).
    :param grad_norm: gradient norm clip (0.0 disables).
    :param pretrained_path: optional warm-start checkpoint path.
    :param sr: sample rate; falls back to module-level SAMPLE_RATE.
    :param scheduler: optional LR scheduler.
    """
    # save project info
    # NOTE(review): attribute name looks like a typo of "pretrained_path";
    # confirm no external reader depends on it before renaming.
    self.pretrained_trained = pretrained_path
    # model
    self.model = model
    self.optimizer = optimizer
    self.scheduler = scheduler
    # log how many parameters in the model (trainable only)
    n_params = sum(p.numel() for p in self.model.parameters()
                   if p.requires_grad)
    log('Model {} was loaded. Total {} params.'.format(
        self.model.__class__.__name__, n_params))
    # adopt repeating function on datasets so next() never exhausts them
    self.train_dataset = self.repeat(train_dataset)
    self.valid_dataset = self.repeat(valid_dataset)
    # save parameters
    self.step = 0
    if sr:
        self.sr = sr
    else:
        self.sr = SAMPLE_RATE
    self.max_step = max_step
    self.save_interval = save_interval
    self.log_interval = log_interval
    self.save_dir = save_dir
    self.save_prefix = save_prefix
    self.grad_clip = grad_clip
    self.grad_norm = grad_norm
    self.valid_max_step = valid_max_step
    # make dirs
    self.log_dir = os.path.join(save_dir, 'logs', self.save_prefix)
    self.model_dir = os.path.join(save_dir, 'models')
    os.makedirs(self.log_dir, exist_ok=True)
    os.makedirs(self.model_dir, exist_ok=True)
    self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
    # load previous checkpoint (may set self.seed / self.step)
    # set seed: fall back to a random one only when no checkpoint had it
    self.seed = None
    self.load()
    if not self.seed:
        self.seed = np.random.randint(np.iinfo(np.int32).max)
    np.random.seed(self.seed)
    torch.manual_seed(self.seed)
    torch.cuda.manual_seed(self.seed)
    # load pretrained model only when training has not already progressed
    if self.step == 0 and pretrained_path:
        self.load_pretrained_model()
    # valid loss trackers (start at float32 max so any real loss improves)
    self.best_valid_loss = np.finfo(np.float32).max
    self.cur_best_valid_loss = self.best_valid_loss
    self.save_valid_loss = np.finfo(np.float32).max
def main(meta_dir: str,
         save_dir: str,
         save_prefix: str,
         pretrained_path: str = '',
         batch_size: int = 32,
         num_workers: int = 8,
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.5, 0.9),
         weight_decay: float = 0.0,
         pretrain_step: int = 200000,
         max_step: int = 1000000,
         save_interval: int = 10000,
         log_scala_interval: int = 20,
         log_heavy_interval: int = 1000,
         gamma: float = 0.5,
         seed: int = 1234):
    """Train a multi-band MelGAN-style vocoder in two phases.

    Phase 1 pretrains the generator with multi-resolution STFT loss only;
    phase 2 trains generator + discriminator adversarially (LSGAN-style).
    Checkpoints, scalar/audio/image tensorboard logs, and validation run
    every ``save_interval`` steps.

    Fix vs. original: in the GAN-phase validation loop the batch index
    from ``enumerate(valid_loader)`` was shadowed by the inner
    per-layer loops (both named ``idx``), so ``count`` picked up a layer
    index and the ``/(count + 1)`` averaging divisor was wrong. The loop
    variables are now distinct (``batch_idx`` vs ``layer_idx``).

    :param meta_dir: dataset metadata directory.
    :param save_dir: root directory for checkpoints/logs.
    :param save_prefix: run name.
    :param pretrained_path: optional checkpoint to resume from.
    :param pretrain_step: steps of generator-only pretraining.
    """
    #
    # prepare training
    #
    # create model (multi-GPU is not required)
    mb_generator = build_model('generator_mb').cuda()
    discriminator = build_model('discriminator_base').cuda()

    # create optimizers
    mb_opt = torch.optim.Adam(mb_generator.parameters(),
                              lr=lr, betas=betas, weight_decay=weight_decay)
    dis_opt = torch.optim.Adam(discriminator.parameters(),
                               lr=lr, betas=betas, weight_decay=weight_decay)

    # make scheduler: decay LR by gamma at fixed milestones
    mb_scheduler = MultiStepLR(mb_opt,
                               list(range(300000, 900000 + 1, 100000)),
                               gamma=gamma)
    dis_scheduler = MultiStepLR(dis_opt,
                                list(range(100000, 700000 + 1, 100000)),
                                gamma=gamma)

    # get datasets
    train_loader, valid_loader = get_datasets(meta_dir,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              crop_length=settings.SAMPLE_RATE,
                                              random_seed=seed)

    # repeat: make the train loader an infinite iterator
    train_loader = repeat(train_loader)

    # build mel function
    mel_func, stft_funcs_for_loss = build_stft_functions()

    # build pqmf (analysis/synthesis filter bank for sub-band processing)
    pqmf_func = PQMF().cuda()

    # prepare logging
    writer, model_dir = prepare_logging(save_dir, save_prefix)

    # Training Saving Attributes
    best_loss = np.finfo(np.float32).max
    initial_step = 0

    # load model
    if pretrained_path:
        log(f'Pretrained path is given : {pretrained_path} . Loading...')
        chk = torch.load(pretrained_path)
        gen_chk, dis_chk = chk['generator'], chk['discriminator']
        gen_opt_chk, dis_opt_chk = chk['gen_opt'], chk['dis_opt']
        initial_step = int(chk['step'])
        loaded_loss = chk['loss']
        mb_generator.load_state_dict(gen_chk)
        discriminator.load_state_dict(dis_chk)
        mb_opt.load_state_dict(gen_opt_chk)
        dis_opt.load_state_dict(dis_opt_chk)
        # scheduler state was only checkpointed by newer runs
        if 'dis_scheduler' in chk:
            dis_scheduler_chk = chk['dis_scheduler']
            gen_scheduler_chk = chk['gen_scheduler']
            mb_scheduler.load_state_dict(gen_scheduler_chk)
            dis_scheduler.load_state_dict(dis_scheduler_chk)
        # NOTE(review): pokes private _step_count to realign schedulers
        # after resume — fragile across torch versions; confirm.
        mb_opt._step_count = initial_step
        mb_scheduler._step_count = initial_step
        dis_opt._step_count = initial_step - pretrain_step
        dis_scheduler._step_count = initial_step - pretrain_step
        mb_scheduler.step(initial_step)
        dis_scheduler.step(initial_step - pretrain_step)
        best_loss = loaded_loss

    #
    # Training !
    #
    # Pretraining generator (STFT loss only, no adversarial term)
    for step in range(initial_step, pretrain_step):
        # data
        wav, _ = next(train_loader)
        wav = wav.cuda()

        # to mel
        mel = mel_func(wav)

        # pqmf
        target_subbands = pqmf_func.analysis(wav.unsqueeze(1))  # N, SUBBAND, T

        # forward
        pred_subbands = mb_generator(mel)
        pred_subbands, _ = match_dim(pred_subbands, target_subbands)

        # pqmf synthesis back to the full-band waveform
        pred = pqmf_func.synthesis(pred_subbands)
        pred, wav = match_dim(pred, wav)

        # get multi-resolution stft loss eq 9)
        loss, mb_loss, fb_loss = get_stft_loss(pred, wav, pred_subbands,
                                               target_subbands,
                                               stft_funcs_for_loss)

        # backward and update
        loss.backward()
        mb_opt.step()
        mb_scheduler.step()
        mb_opt.zero_grad()
        mb_generator.zero_grad()

        #
        # logging! save!
        #
        if step % log_scala_interval == 0 and step > 0:
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('train/pretrain_loss', loss.item(), global_step=step)
            writer.add_scalar('train/mb_loss', mb_loss.item(), global_step=step)
            writer.add_scalar('train/fb_loss', fb_loss.item(), global_step=step)
            if step % log_heavy_interval == 0:
                writer.add_audio('train/pred_audio', pred_audio,
                                 sample_rate=settings.SAMPLE_RATE, global_step=step)
                writer.add_audio('train/target_audio', target_audio,
                                 sample_rate=settings.SAMPLE_RATE, global_step=step)
            # console
            msg = f'train: step: {step} / loss: {loss.item()} / mb_loss: {mb_loss.item()} / fb_loss: {fb_loss.item()}'
            log(msg)

        if step % save_interval == 0 and step > 0:
            #
            # Validation Step !
            #
            valid_loss = 0.
            valid_mb_loss, valid_fb_loss = 0., 0.
            count = 0
            mb_generator.eval()
            for batch_idx, (wav, _) in enumerate(valid_loader):
                # setup data
                wav = wav.cuda()
                mel = mel_func(wav)
                with torch.no_grad():
                    # pqmf
                    target_subbands = pqmf_func.analysis(
                        wav.unsqueeze(1))  # N, SUBBAND, T
                    # forward
                    pred_subbands = mb_generator(mel)
                    pred_subbands, _ = match_dim(pred_subbands, target_subbands)
                    # pqmf synthesis
                    pred = pqmf_func.synthesis(pred_subbands)
                    pred, wav = match_dim(pred, wav)
                    # get stft loss
                    loss, mb_loss, fb_loss = get_stft_loss(
                        pred, wav, pred_subbands, target_subbands,
                        stft_funcs_for_loss)
                    valid_loss += loss.item()
                    valid_mb_loss += mb_loss.item()
                    valid_fb_loss += fb_loss.item()
                    count = batch_idx

            valid_loss /= (count + 1)
            valid_mb_loss /= (count + 1)
            valid_fb_loss /= (count + 1)
            mb_generator.train()

            # log validation (uses the last batch for audio samples)
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('valid/pretrain_loss', valid_loss, global_step=step)
            writer.add_scalar('valid/mb_loss', valid_mb_loss, global_step=step)
            writer.add_scalar('valid/fb_loss', valid_fb_loss, global_step=step)
            writer.add_audio('valid/pred_audio', pred_audio,
                             sample_rate=settings.SAMPLE_RATE, global_step=step)
            writer.add_audio('valid/target_audio', target_audio,
                             sample_rate=settings.SAMPLE_RATE, global_step=step)
            # console
            log(f'---- Valid loss: {valid_loss} / mb_loss: {valid_mb_loss} / fb_loss: {valid_fb_loss} ----')

            #
            # save checkpoint
            #
            is_best = valid_loss < best_loss
            if is_best:
                best_loss = valid_loss
            save_checkpoint(mb_generator, discriminator, mb_opt, dis_opt,
                            mb_scheduler, dis_scheduler, model_dir, step,
                            valid_loss, is_best=is_best)

    #
    # Train GAN
    #
    dis_block_layers = 6    # discriminator outputs its final score every 6 features
    lambda_gen = 2.5        # weight on the adversarial generator term
    best_loss = np.finfo(np.float32).max

    for step in range(max(pretrain_step, initial_step), max_step):
        # data
        wav, _ = next(train_loader)
        wav = wav.cuda()

        # to mel
        mel = mel_func(wav)

        # pqmf
        target_subbands = pqmf_func.analysis(wav.unsqueeze(1))  # N, SUBBAND, T

        #
        # Train Discriminator
        #
        # forward
        pred_subbands = mb_generator(mel)
        pred_subbands, _ = match_dim(pred_subbands, target_subbands)

        # pqmf synthesis
        pred = pqmf_func.synthesis(pred_subbands)
        pred, wav = match_dim(pred, wav)

        # mel reconstruction error, for monitoring only
        with torch.no_grad():
            pred_mel = mel_func(pred.squeeze(1).detach())
            mel_err = F.l1_loss(mel, pred_mel).item()

        # if terminate_step > step:
        d_fake_det = discriminator(pred.detach())
        d_real = discriminator(wav.unsqueeze(1))

        # calculate discriminator losses eq 1)
        # NOTE(review): this pushes D(fake)->1 and D(real)->0, the opposite
        # of the usual LSGAN discriminator objective — confirm the intended
        # label convention before changing.
        loss_D = 0
        for layer_idx in range(dis_block_layers - 1, len(d_fake_det),
                               dis_block_layers):
            loss_D += torch.mean((d_fake_det[layer_idx] - 1)**2)
        for layer_idx in range(dis_block_layers - 1, len(d_real),
                               dis_block_layers):
            loss_D += torch.mean(d_real[layer_idx]**2)

        # train
        discriminator.zero_grad()
        loss_D.backward()
        dis_opt.step()
        dis_scheduler.step()

        #
        # Train Generator
        #
        d_fake = discriminator(pred)

        # calc generator loss eq 8)
        loss_G = 0
        for layer_idx in range(dis_block_layers - 1, len(d_fake),
                               dis_block_layers):
            loss_G += ((d_fake[layer_idx] - 1)**2).mean()
        loss_G *= lambda_gen

        # get multi-resolution stft loss
        loss_G += get_stft_loss(pred, wav, pred_subbands, target_subbands,
                                stft_funcs_for_loss)[0]
        # loss_G += get_spec_losses(pred, wav, stft_funcs_for_loss)[0]

        mb_generator.zero_grad()
        loss_G.backward()
        mb_opt.step()
        mb_scheduler.step()

        #
        # logging! save!
        #
        if step % log_scala_interval == 0 and step > 0:
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('train/loss_G', loss_G.item(), global_step=step)
            writer.add_scalar('train/loss_D', loss_D.item(), global_step=step)
            writer.add_scalar('train/mel_err', mel_err, global_step=step)
            if step % log_heavy_interval == 0:
                target_mel = imshow_to_buf(mel[0].detach().cpu().numpy())
                pred_mel = imshow_to_buf(
                    mel_func(pred[:1, 0])[0].detach().cpu().numpy())
                writer.add_image('train/target_mel', target_mel, global_step=step)
                writer.add_image('train/pred_mel', pred_mel, global_step=step)
                writer.add_audio('train/pred_audio', pred_audio,
                                 sample_rate=settings.SAMPLE_RATE, global_step=step)
                writer.add_audio('train/target_audio', target_audio,
                                 sample_rate=settings.SAMPLE_RATE, global_step=step)
            # console
            msg = f'train: step: {step} / loss_G: {loss_G.item()} / loss_D: {loss_D.item()} / ' \
                  f' mel_err: {mel_err}'
            log(msg)

        if step % save_interval == 0 and step > 0:
            #
            # Validation Step !
            #
            valid_g_loss, valid_d_loss, valid_mel_loss = 0., 0., 0.
            count = 0
            mb_generator.eval()
            discriminator.eval()
            # batch_idx / layer_idx kept distinct so `count` really counts
            # batches (the original shadowed `idx` here)
            for batch_idx, (wav, _) in enumerate(valid_loader):
                # setup data
                wav = wav.cuda()
                mel = mel_func(wav)
                with torch.no_grad():
                    # pqmf
                    target_subbands = pqmf_func.analysis(
                        wav.unsqueeze(1))  # N, SUBBAND, T
                    # Discriminator
                    pred_subbands = mb_generator(mel)
                    pred_subbands, _ = match_dim(pred_subbands, target_subbands)
                    # pqmf synthesis
                    pred = pqmf_func.synthesis(pred_subbands)
                    pred, wav = match_dim(pred, wav)
                    # Mel Error
                    pred_mel = mel_func(pred.squeeze(1).detach())
                    mel_err = F.l1_loss(mel, pred_mel).item()

                    #
                    # discriminator part
                    #
                    d_fake_det = discriminator(pred.detach())
                    d_real = discriminator(wav.unsqueeze(1))

                    loss_D = 0
                    for layer_idx in range(dis_block_layers - 1,
                                           len(d_fake_det), dis_block_layers):
                        loss_D += torch.mean((d_fake_det[layer_idx] - 1)**2)
                    for layer_idx in range(dis_block_layers - 1, len(d_real),
                                           dis_block_layers):
                        loss_D += torch.mean(d_real[layer_idx]**2)

                    #
                    # generator part
                    #
                    d_fake = discriminator(pred)

                    # calc generator loss
                    loss_G = 0
                    for layer_idx in range(dis_block_layers - 1, len(d_fake),
                                           dis_block_layers):
                        loss_G += ((d_fake[layer_idx] - 1)**2).mean()
                    loss_G *= lambda_gen

                    # get stft loss
                    stft_loss = get_stft_loss(pred, wav, pred_subbands,
                                              target_subbands,
                                              stft_funcs_for_loss)[0]
                    loss_G += stft_loss

                    valid_d_loss += loss_D.item()
                    valid_g_loss += loss_G.item()
                    valid_mel_loss += mel_err
                    count = batch_idx

            valid_d_loss /= (count + 1)
            valid_g_loss /= (count + 1)
            valid_mel_loss /= (count + 1)
            mb_generator.train()
            discriminator.train()

            # log validation (uses the last batch for audio/image samples)
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            target_mel = imshow_to_buf(mel[0].detach().cpu().numpy())
            pred_mel = imshow_to_buf(
                mel_func(pred[:1, 0])[0].detach().cpu().numpy())
            writer.add_image('valid/target_mel', target_mel, global_step=step)
            writer.add_image('valid/pred_mel', pred_mel, global_step=step)
            writer.add_scalar('valid/loss_G', valid_g_loss, global_step=step)
            writer.add_scalar('valid/loss_D', valid_d_loss, global_step=step)
            writer.add_scalar('valid/mel_err', valid_mel_loss, global_step=step)
            writer.add_audio('valid/pred_audio', pred_audio,
                             sample_rate=settings.SAMPLE_RATE, global_step=step)
            writer.add_audio('valid/target_audio', target_audio,
                             sample_rate=settings.SAMPLE_RATE, global_step=step)
            # console
            log(f'---- loss_G: {valid_g_loss} / loss_D: {valid_d_loss} / mel loss : {valid_mel_loss} ----')

            #
            # save checkpoint
            #
            is_best = valid_g_loss < best_loss
            if is_best:
                best_loss = valid_g_loss
            save_checkpoint(mb_generator, discriminator, mb_opt, dis_opt,
                            mb_scheduler, dis_scheduler, model_dir, step,
                            valid_g_loss, is_best=is_best)

    log('----- Finish ! -----')