class TensorboardLogger(Logger):
    """TensorBoard backend for the step-driven ``Logger`` base class.

    The scheduling intervals, the trainer and the sample-generation callback
    are passed straight to ``Logger.__init__``; this subclass only adds a
    ``SummaryWriter`` writing to ``log_dir`` and the methods that emit to it.
    """

    def __init__(self, log_interval=50, validation_interval=200,
                 generate_interval=500, trainer=None, generate_function=None,
                 log_dir='logs'):
        super().__init__(log_interval, validation_interval, generate_interval,
                         trainer, generate_function)
        self.writer = SummaryWriter(log_dir)

    def log_loss(self, current_step):
        """Write the training loss averaged over ``log_interval`` steps."""
        # NOTE(review): ``accumulated_loss`` is not set in this class; it is
        # presumably the running loss sum kept by the base Logger — confirm.
        avg_loss = self.accumulated_loss / self.log_interval
        self.writer.add_scalar('loss', avg_loss, current_step)

    def validate(self, current_step):
        """Run ``trainer.validate()`` and log the returned loss and accuracy."""
        avg_loss, avg_accuracy = self.trainer.validate()
        self.writer.add_scalar('validation/loss', avg_loss, current_step)
        self.writer.add_scalar('validation/accuracy', avg_accuracy, current_step)

    def log_audio(self, step):
        """Call ``generate_function`` and log its output as 16 kHz audio."""
        samples = self.generate_function()
        self.writer.add_audio('audio sample', samples, step, sample_rate=16000)

    def image_summary(self, tag, images, step):
        """Log a batch of images (passed straight to ``add_images``)."""
        self.writer.add_images(tag, images, step)

    def audio_summary(self, tag, sample, step, sr=16000):
        """Log one audio clip under ``tag`` at ``step`` with sample rate ``sr``."""
        # ``step`` used to be dropped here, so every clip was logged without a
        # global step and could not be lined up with the other summaries.
        self.writer.add_audio(tag, sample, step, sample_rate=sr)
class TensorboardHandler(object):
    """Write parameter histograms and loss histories to TensorBoard.

    Anything already present at ``path`` is deleted first, so every run
    starts from an empty log directory.
    """

    def __init__(self, path="runs"):
        import contextlib
        import shutil

        # This used to be ``os.system(f'rm -rf {path}')``, which sent ``path``
        # through a shell (injection risk, breaks on spaces/quotes). Keep the
        # same best-effort semantics without invoking a shell.
        if os.path.islink(path) or os.path.isfile(path):
            with contextlib.suppress(OSError):
                os.remove(path)
        else:
            shutil.rmtree(path, ignore_errors=True)
        self.writer = SummaryWriter(path)
        self._counter = 0
        self.models = []
        self.losses = []

    def update(self, models, losses, epoch=None):
        """Log ``models`` and ``losses``; ``epoch`` defaults to a call counter."""
        if epoch is None:
            epoch = self._counter
        # NOTE(review): ``apply_method`` is defined elsewhere; presumably it
        # dispatches ``self.<name>`` over ``models``/``losses`` — confirm.
        apply_method(self, "add_model", models, epoch=epoch)
        apply_method(self, "add_loss", losses, epoch=epoch)
        self._counter += 1

    def add_model(self, model, epoch):
        """Log one histogram per named parameter under ``model/<name>``."""
        for name, param in model.named_parameters():
            self.writer.add_histogram('model/' + name,
                                      param.detach().cpu().numpy(), epoch)

    def add_loss(self, losses, epoch):
        """Log the most recent value of every tracked loss term.

        ``loss.loss_history`` is read as ``{split: {term: {'values': [...]}}}``.
        The term names are taken from the first split. For each term, one
        ``add_scalars`` call writes the last value of every split; splits with
        no history are written as 0.
        """
        for loss in losses:
            history = loss.loss_history
            if not history:
                # Nothing recorded yet; indexing the first split would fail.
                continue
            loss_names = list(next(iter(history.values())).keys())
            for term in loss_names:
                last_loss = {
                    split: np.array(values[term]['values'][-1]) if len(values) > 0 else 0.
                    for split, values in history.items()
                }
                self.writer.add_scalars('loss_%s/' % term, last_loss, epoch)

    def add_image(self, name, pic, epoch):
        """Log a single image at ``epoch``."""
        self.writer.add_image(name, pic, epoch)

    def add_audio(self, name, audio, sr, epoch=None):
        """Log an audio clip at sample rate ``sr``, optionally at ``epoch``."""
        self.writer.add_audio(name, audio, epoch, sample_rate=sr)
def validate(model: Model, val_loader: DataLoader, writer: SummaryWriter,
             iteration: int, waveglow_path: Path = None, fp16: bool = False):
    """Evaluate ``model`` on ``val_loader`` and log the results to TensorBoard.

    The following are logged:

    * the mean loss (mel MSE + post-net mel MSE + stop-token BCE);
    * the mean ``success_rate`` of the alignments;
    * the post-net mel and alignment of the first item of one randomly chosen
      batch.

    When ``waveglow_path`` is given, that mel is also vocoded with
    ``waveglow_gen`` and logged as 22050 Hz audio. The model's previous
    train/eval mode is restored before returning.

    Returns:
        The mean validation loss over all batches.
    """
    # Fixed torch seeds make the pass repeatable. The numpy seed depends on
    # ``iteration``, so different calls may plot a different batch.
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    np.random.seed(iteration % 42)

    was_training = model.training
    model.eval()
    start = time.time()
    metric_values, success_rates = [], []
    # Index of the batch whose first item gets plotted / synthesized.
    random_idx = int(np.random.choice(len(val_loader), 1)[0])

    with torch.no_grad():
        for i, (text, input_lengths, mel, stop_target) in enumerate(val_loader):
            text = text.to("cuda", non_blocking=True)
            input_lengths = input_lengths.to("cuda", non_blocking=True)
            mel = mel.to("cuda", non_blocking=True)
            stop_target = stop_target.to("cuda", non_blocking=True)

            mel_pred, mel_pred_postnet, stop_predictions, alignment = model(
                text, input_lengths, mel)
            mel_loss = F.mse_loss(mel_pred, mel)
            mel_postnet_loss = F.mse_loss(mel_pred_postnet, mel)
            stop_loss = F.binary_cross_entropy(stop_predictions, stop_target)
            loss = mel_loss + mel_postnet_loss + stop_loss

            metric_values.append(loss.item())
            success_rates.append(success_rate(alignment))
            if i == random_idx:
                mel_to_gen = mel_pred_postnet[0]
                gen_alignment = alignment[0]

    avg_loss = np.mean(metric_values)
    avg_success_rate = np.mean(success_rates)

    writer.add_scalar('loss/validation', avg_loss, iteration)
    writer.add_scalar('success_rate/validation', avg_success_rate, iteration)
    writer.add_image("mel_pred/validation",
                     show_figure(mel_to_gen.float().cpu().numpy()),
                     iteration, dataformats='HWC')
    writer.add_image("alignment/validation",
                     show_figure(gen_alignment.float().cpu().numpy(), origin='lower'),
                     iteration, dataformats='HWC')
    if waveglow_path:
        print('start to audio synthesis')
        writer.add_audio("audio/validation",
                         waveglow_gen(waveglow_path, mel_to_gen[None], fp16=fp16),
                         iteration, sample_rate=22050)

    # Don't leave dropout etc. disabled for a caller that was training.
    if was_training:
        model.train()

    end = time.time()
    # Report the mean loss; this used to print only the last batch's loss.
    print(f"validation {iteration}, loss={avg_loss:.3f}, {end - start:.2f} s.")
    return avg_loss
class TensorLogger(object):
    """Thin convenience wrapper around a ``SummaryWriter``.

    The log directory ``_logdir`` (default ``./runs/``) is created if it
    does not exist yet.
    """

    def __init__(self, _logdir='./runs/'):
        os.makedirs(_logdir, exist_ok=True)
        self.writer = SummaryWriter(log_dir=_logdir)

    def scalar_summary(self, _tag, _value, _step):
        """Log a scalar value."""
        self.writer.add_scalar(_tag, _value, _step)

    def image_summary(self, _tag, _image, _step, _format='CHW'):
        """Log a single image.

        ``_format`` is forwarded as ``dataformats`` and must match the image
        tensor layout: ``'CHW'`` for (3, H, W) or (1, H, W), ``'HWC'`` for
        (H, W, 3), ``'HW'`` for (H, W).
        """
        # This call used to be commented out (and passed the invalid keyword
        # ``dataformat``), so images were silently never logged.
        self.writer.add_image(_tag, _image, _step, dataformats=_format)

    def figure_summary(self, _tag, _figure, _step):
        """Log a matplotlib figure."""
        self.writer.add_figure(_tag, _figure, _step)

    def video_summary(self, _tag, _video, _step, _fps=4):
        """Log a video.

        ``_video`` is expected as (N, T, C, H, W) with values in [0, 255] for
        uint8 or [0, 1] for float32; ``_fps`` defaults to 4.
        """
        self.writer.add_video(_tag, _video, _step, fps=_fps)

    def audio_summary(self, _tag, _sound, _step, _sampleRate=44100):
        """Log an audio clip.

        ``_sound`` is expected as (1, L) with values in [-1, 1];
        ``_sampleRate`` defaults to 44100 Hz.
        """
        self.writer.add_audio(_tag, _sound, _step, sample_rate=_sampleRate)

    def text_summary(self, _tag, _textString, _step):
        """Log a text string."""
        self.writer.add_text(_tag, _textString, _step)

    def histogram_summary(self, _tag, _histogram, _step, _bins='tensorflow'):
        """Log a histogram of ``_histogram`` using the ``_bins`` strategy."""
        # Was ``add_histrogram`` (typo), which raised AttributeError.
        self.writer.add_histogram(_tag, _histogram, _step, bins=_bins)
class TensorBoardLogger(Logger):
    """TensorBoard logger with helpers for piano-roll images and audio.

    Keyword arguments (positional ``*args`` are accepted but ignored):
        log_dir: directory for the event files, created if missing. It also
            holds the temporary ``tmp.mid``/``tmp.wav`` written by
            ``add_pianoroll_audio``.
        track_info: optional dict. It is forwarded as keyword arguments to
            ``generate_multitrack``, and its ``'instruments'`` entry names
            the tracks in ``add_pianoroll_img``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        path = Path(kwargs['log_dir'])
        path.mkdir(parents=True, exist_ok=True)
        self.log_dir = path
        self.writer = SummaryWriter(path)
        self.track_info = kwargs.get('track_info', None)
        if self.track_info is not None:
            self.generate_multitrack = partial(generate_multitrack, **self.track_info)

    def add_scalar(self, name, value, step):
        """Log a scalar at ``step``."""
        self.writer.add_scalar(name, value, global_step=step)

    def add_image(self, name, value, step):
        """Log an image tensor shaped (3, H, W) or (1, H, W) at ``step``."""
        self.writer.add_image(name, value, global_step=step)

    def add_figure(self, name, fig, step):
        """Log a matplotlib figure at ``step``."""
        self.writer.add_figure(name, fig, global_step=step)

    def add_audio(self, name, value, sample_rate, step):
        """Log an audio clip at ``step`` with the given ``sample_rate``."""
        self.writer.add_audio(name, value, global_step=step, sample_rate=sample_rate)

    def add_histogram(self, name, value, step):
        """Log a histogram of ``value`` at ``step``."""
        self.writer.add_histogram(name, value, global_step=step)

    def add_pianoroll_img(self, name, pianoroll, step):
        """Log one single-channel image per instrument track.

        ``pianoroll`` has dim (seq_length, instruments, n_pitches). Each track
        is transposed to (n_pitches, seq_length) and given a leading channel
        axis, so it matches ``add_image``'s (1, H, W) layout.
        """
        for i, instrument in enumerate(self.track_info['instruments']):
            # Index the i-th track; this used to always slice track 0, so every
            # instrument was logged with the first instrument's piano roll.
            track = pianoroll[:, i, :].T
            self.add_image("{}_{}".format(name, instrument),
                           np.expand_dims(track, 0), step)

    def add_pianoroll_audio(self, name, pianoroll, step):
        """Render ``pianoroll`` to MIDI, synthesize it and log the audio.

        Uses ``tmp.mid`` / ``tmp.wav`` inside ``log_dir`` as scratch files.
        """
        import librosa
        from midi2audio import FluidSynth

        midi_file = str(self.log_dir / "tmp.mid")
        wav_file = str(self.log_dir / "tmp.wav")
        multitrack = self.generate_multitrack(pianoroll)
        multitrack.write(midi_file)
        FluidSynth().midi_to_audio(midi_file, wav_file)
        y, sr = librosa.load(wav_file)
        self.add_audio(name, y, sr, step)

    def close(self):
        """Flush and close the underlying writer."""
        self.writer.close()
class VisualizerTensorboard:
    """Route named values to the matching ``SummaryWriter.add_*`` method.

    Call :meth:`register` with ``{name: {'dtype': kind, ...}}`` to declare how
    each name is logged. Then call :meth:`update` with ``{name: value}`` (or
    an iterable of ``(name, value)`` pairs). Every ``update`` logs at the
    current ``iteration`` and then advances it by one.
    """

    # Kinds whose writer method is ``add_<kind>(tag, value, global_step)``.
    _POSITIONAL_KINDS = frozenset({
        'scalar', 'scalars', 'histogram', 'image', 'images',
        'figure', 'video', 'audio', 'text',
    })

    def __init__(self, opts):
        self.dtype = {}
        self.iteration = 1
        self.writer = SummaryWriter(opts.logs_dir)

    def register(self, modules):
        """Record the ``'dtype'`` of every entry of the ``modules`` dict."""
        for key in modules:
            self.dtype[key] = modules[key]['dtype']

    def update(self, modules):
        """Log every value in ``modules`` at the current step, then advance it.

        Raises:
            KeyError: a name was never registered.
            ValueError: a registered dtype is not supported.
        """
        from collections.abc import Mapping

        # Iterating a dict directly yields keys only, which broke the
        # ``key, value`` unpacking; accept both dicts and pair iterables.
        pairs = modules.items() if isinstance(modules, Mapping) else modules
        for key, value in pairs:
            self._log(key, self.dtype[key], value)
        self.iteration = self.iteration + 1

    def _log(self, key, kind, value):
        """Dispatch a single ``value`` to the writer method for ``kind``."""
        step = self.iteration
        writer = self.writer
        if kind in self._POSITIONAL_KINDS:
            getattr(writer, 'add_' + kind)(key, value, step)
        elif kind == 'embedding':
            # add_embedding's first argument is the matrix, not the tag.
            writer.add_embedding(value, global_step=step, tag=key)
        elif kind == 'pr_curve':
            writer.add_pr_curve(key, value['labels'], value['predictions'], step)
        elif kind == 'mesh':
            # add_mesh's third positional parameter is ``colors``, so the
            # step must be passed by keyword; ``value`` holds the vertices.
            writer.add_mesh(key, value, global_step=step)
        elif kind == 'hparams':
            # add_hparams takes no tag; its positional args are the two dicts.
            writer.add_hparams(value['hparam_dict'], value['metric_dict'])
        else:
            raise ValueError(
                'Data type not supported, please update the visualizer plugin and rerun !!'
            )
class SummaryWriter:
    """On/off switchable proxy around a TensorBoard summary writer.

    ``add_*`` calls are forwarded to the wrapped ``TensorboardSummaryWriter``
    only while ``active`` is true. Methods that take ``global_step`` fall back
    to ``self.global_step`` when the caller passes ``None``.

    Constructing an instance also publishes its bound methods on this module:
    ``add_<x>`` becomes ``<x>`` and ``set_<x>`` keeps its name. Other code can
    then log through the module directly; the most recently constructed
    instance wins.
    """

    def __init__(self, logdir, flush_secs=120):
        self.writer = TensorboardSummaryWriter(
            log_dir=logdir,
            purge_step=None,
            max_queue=10,
            flush_secs=flush_secs,
            filename_suffix='',
        )
        self.global_step = None
        self.active = True

        module = sys.modules[__name__]
        for attr_name in dir(SummaryWriter):
            if attr_name.startswith('add_'):
                # strip the 4-character 'add_' prefix
                setattr(module, attr_name[4:], getattr(self, attr_name))
            if attr_name.startswith('set_'):
                setattr(module, attr_name, getattr(self, attr_name))

    def _resolve_step(self, global_step):
        """Return ``global_step``, or the default step when it is None."""
        if global_step is None:
            return self.global_step
        return global_step

    # ---------------------------------------------------------------- setters
    def set_global_step(self, value):
        self.global_step = value

    def set_active(self, value):
        self.active = value

    # --------------------------------------------------------- step-less calls
    def add_custom_scalars(self, layout):
        if not self.active:
            return
        self.writer.add_custom_scalars(layout)

    def add_custom_scalars_marginchart(self, tags, category='default', title='untitled'):
        if not self.active:
            return
        self.writer.add_custom_scalars_marginchart(tags, category=category, title=title)

    def add_custom_scalars_multilinechart(self, tags, category='default', title='untitled'):
        if not self.active:
            return
        self.writer.add_custom_scalars_multilinechart(tags, category=category, title=title)

    def add_graph(self, model, input_to_model=None, verbose=False):
        if not self.active:
            return
        self.writer.add_graph(model, input_to_model=input_to_model, verbose=verbose)

    def add_onnx_graph(self, graph):
        if not self.active:
            return
        self.writer.add_onnx_graph(graph)

    # ---------------------------------------------------------- stepped calls
    def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_audio(tag, snd_tensor, global_step=step,
                              sample_rate=sample_rate, walltime=walltime)

    def add_embedding(self, mat, metadata=None, label_img=None, global_step=None,
                      tag='default', metadata_header=None):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_embedding(mat, metadata=metadata, label_img=label_img,
                                  global_step=step, tag=tag,
                                  metadata_header=metadata_header)

    def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_figure(tag, figure, global_step=step, close=close, walltime=walltime)

    def add_histogram(self, tag, values, global_step=None, bins='tensorflow',
                      walltime=None, max_bins=None):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_histogram(tag, values, global_step=step, bins=bins,
                                  walltime=walltime, max_bins=max_bins)

    def add_histogram_raw(self, tag, min, max, num, sum, sum_squares, bucket_limits,
                          bucket_counts, global_step=None, walltime=None):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_histogram_raw(
            tag, min=min, max=max, num=num, sum=sum, sum_squares=sum_squares,
            bucket_limits=bucket_limits, bucket_counts=bucket_counts,
            global_step=step, walltime=walltime)

    def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_image(tag, img_tensor, global_step=step,
                              walltime=walltime, dataformats=dataformats)

    def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None,
                             walltime=None, rescale=1, dataformats='CHW'):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_image_with_boxes(tag, img_tensor, box_tensor, global_step=step,
                                         walltime=walltime, rescale=rescale,
                                         dataformats=dataformats)

    def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_images(tag, img_tensor, global_step=step,
                               walltime=walltime, dataformats=dataformats)

    def add_mesh(self, tag, vertices, colors=None, faces=None, config_dict=None,
                 global_step=None, walltime=None):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_mesh(tag, vertices, colors=colors, faces=faces,
                             config_dict=config_dict, global_step=step, walltime=walltime)

    def add_pr_curve(self, tag, labels, predictions, global_step=None,
                     num_thresholds=127, weights=None, walltime=None):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_pr_curve(tag, labels, predictions, global_step=step,
                                 num_thresholds=num_thresholds, weights=weights,
                                 walltime=walltime)

    def add_pr_curve_raw(self, tag, true_positive_counts, false_positive_counts,
                         true_negative_counts, false_negative_counts, precision, recall,
                         global_step=None, num_thresholds=127, weights=None, walltime=None):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_pr_curve_raw(
            tag, true_positive_counts, false_positive_counts, true_negative_counts,
            false_negative_counts, precision, recall, global_step=step,
            num_thresholds=num_thresholds, weights=weights, walltime=walltime)

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_scalar(tag, scalar_value, global_step=step, walltime=walltime)

    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_scalars(main_tag, tag_scalar_dict, global_step=step, walltime=walltime)

    def add_text(self, tag, text_string, global_step=None, walltime=None):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_text(tag, text_string, global_step=step, walltime=walltime)

    def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):
        if not self.active:
            return
        step = self._resolve_step(global_step)
        self.writer.add_video(tag, vid_tensor, global_step=step, fps=fps, walltime=walltime)

    # -------------------------------------------------------------- lifecycle
    def close(self):
        self.writer.close()

    def __enter__(self):
        # NOTE: returns the wrapped writer, not this proxy.
        return self.writer.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self.writer.__exit__(exc_type, exc_val, exc_tb)
def train_melgan(args):
    """Train a MelGAN generator/discriminator pair.

    Everything is written under ``args.save_path``:

    * ``args.json`` – the parsed arguments;
    * ``events/`` – TensorBoard event files;
    * ``original/`` – reference test audio;
    * ``generated/`` – generated test audio;
    * ``models/`` – periodic checkpoints and best-so-far snapshots, ranked by
      mean mel-reconstruction error.

    If ``args.load_path`` points to an existing file, only the generator
    weights are restored from it; both optimizers and the discriminator
    start fresh.
    """
    root = Path(args.save_path)
    load_root = Path(args.load_path) if args.load_path else None
    root.mkdir(parents=True, exist_ok=True)

    # ---------------------------------------------------------------------
    # Dump arguments and create logger
    # ---------------------------------------------------------------------
    with open(root / "args.json", "w", encoding="utf8") as f:
        json.dump(args.__dict__, f, indent=4, ensure_ascii=False)
    eventdir = root / "events"
    eventdir.mkdir(exist_ok=True)
    writer = SummaryWriter(str(eventdir))

    # ---------------------------------------------------------------------
    # Models and mel front-end
    # ---------------------------------------------------------------------
    ratios = [int(w) for w in args.ratios.split()]
    netG = Generator(args.n_mel_channels, args.ngf, args.n_residual_layers,
                     ratios=ratios).to(_device)
    netD = Discriminator(args.num_D, args.ndf, args.n_layers_D,
                         args.downsamp_factor).to(_device)

    if args.mode == 'default':
        fft = audio2mel
    elif args.mode == 'synthesizer':
        fft = audio2mel_synthesizer
    elif args.mode == 'mellotron':
        fft = audio2mel_mellotron
    else:
        raise KeyError("unknown mode {!r}; expected 'default', 'synthesizer' "
                       "or 'mellotron'".format(args.mode))

    # ---------------------------------------------------------------------
    # Optimizers / optional generator warm start
    # ---------------------------------------------------------------------
    optG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optD = torch.optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))

    if load_root and load_root.exists():
        # map_location lets a checkpoint saved on another device load here.
        netG.load_state_dict(torch.load(load_root, map_location=_device))

    # ---------------------------------------------------------------------
    # Data loaders
    # ---------------------------------------------------------------------
    train_set = AudioDataset(Path(args.data_path), args.seq_len,
                             sampling_rate=args.sample_rate)
    test_set = AudioDataset(
        Path(args.data_path),  # test file
        args.sample_rate * 4,
        sampling_rate=args.sample_rate,
        augment=False,
    )
    train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=1)

    # ---------------------------------------------------------------------
    # Cache test mels and dump the original test audio
    # ---------------------------------------------------------------------
    test_voc = []
    test_audio = []
    for i, x_t in enumerate(test_loader):
        x_t = x_t.to(_device)
        s_t = fft(x_t).detach()
        test_voc.append(s_t.to(_device))
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        oridir = root / "original"
        oridir.mkdir(exist_ok=True)
        save_sample(oridir / ("original_{}_{}.wav".format("test", i)),
                    args.sample_rate, audio)
        writer.add_audio("original/{}/sample_{}.wav".format("test", i), audio, 0,
                         sample_rate=args.sample_rate)

        if i == args.n_test_samples - 1:
            break

    costs = []
    start = time.time()

    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000
    step_begin = args.start_step
    # Extra early checkpoints/log lines shortly after (re)starting.
    look_steps = {step_begin + 10, step_begin + 100, step_begin + 1000, step_begin + 10000}
    steps = step_begin
    for epoch in range(1, args.epochs + 1):
        print("\nEpoch {} beginning. Current step: {}".format(epoch, steps))
        for iterno, x_t in enumerate(tqdm(train_loader, desc="iter", ncols=100)):
            x_t = x_t.to(_device)
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.to(_device))

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            # -------------------- Train discriminator (hinge loss)
            D_fake_det = netD(x_pred_t.to(_device).detach())
            D_real = netD(x_t.to(_device))

            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()
            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            # -------------------- Train generator (adversarial + feature matching)
            D_fake = netD(x_pred_t.to(_device))

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

            # The last entry of each scale is the score; the others are
            # intermediate feature maps used for feature matching.
            loss_feat = 0
            feat_weights = 4.0 / (args.n_layers_D + 1)
            D_weights = 1.0 / args.num_D
            wt = D_weights * feat_weights
            for i in range(args.num_D):
                for j in range(len(D_fake[i]) - 1):
                    loss_feat += wt * F.l1_loss(D_fake[i][j], D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + args.lambda_feat * loss_feat).backward()
            optG.step()

            # -------------------- TensorBoard
            costs.append([loss_D.item(), loss_G.item(), loss_feat.item(), s_error])
            steps += 1
            writer.add_scalar("loss/discriminator", costs[-1][0], steps)
            writer.add_scalar("loss/generator", costs[-1][1], steps)
            writer.add_scalar("loss/feature_matching", costs[-1][2], steps)
            writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps)

            if steps % args.save_interval == 0 or steps in look_steps:
                gendir = root / "generated"
                gendir.mkdir(exist_ok=True)
                with torch.no_grad():
                    for i, voc in enumerate(test_voc):
                        pred_audio = netG(voc).squeeze().cpu()
                        save_sample(gendir / ("generated_step{}_{}.wav".format(steps, i)),
                                    args.sample_rate, pred_audio)
                        # Log at ``steps`` (was ``epoch``) so the audio lines up
                        # with the scalar curves and with the tag/file name.
                        writer.add_audio(
                            "generated/step{}/sample_{}.wav".format(steps, i),
                            pred_audio,
                            steps,
                            sample_rate=args.sample_rate,
                        )

                ptdir = root / "models"
                ptdir.mkdir(exist_ok=True)
                torch.save(netG.state_dict(), ptdir / "step{}_netG.pt".format(steps))
                torch.save(optG.state_dict(), ptdir / "step{}_optG.pt".format(steps))
                torch.save(netD.state_dict(), ptdir / "step{}_netD.pt".format(steps))
                torch.save(optD.state_dict(), ptdir / "step{}_optD.pt".format(steps))

                # Mean mel-reconstruction error since the last log reset.
                mean_mel_error = np.asarray(costs).mean(0)[-1]
                if mean_mel_error < best_mel_reconst:
                    best_mel_reconst = mean_mel_error
                    torch.save(netD.state_dict(), ptdir / "best_step{}_netD.pt".format(steps))
                    torch.save(netG.state_dict(), ptdir / "best_step{}_netG.pt".format(steps))

            if steps % args.log_interval == 0 or steps in look_steps:
                # ``look_steps`` can trigger this before ``log_interval`` batches
                # have run, so average over the batches actually accumulated.
                print(
                    "\nEpoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".format(
                        epoch,
                        iterno,
                        len(train_loader),
                        1000 * (time.time() - start) / len(costs),
                        np.asarray(costs).mean(0),
                    )
                )
                costs = []
                start = time.time()
class VocTrainer:
    """Schedule-driven trainer for a WaveRNN vocoder.

    Besides the training/validation loop, it periodically generates audio for
    a few validation samples. Each generation is scored by the L1 distance
    between the mels of the generated and target waveforms, and the
    ``keep_top_k`` best-scoring models are kept on disk under
    ``paths.voc_top_k`` (ranking persisted in ``top_k.pkl``).
    """

    def __init__(self, paths: Paths, dsp: DSP, config: Dict[str, Any]) -> None:
        self.paths = paths
        self.writer = SummaryWriter(log_dir=paths.voc_log, comment='v1')
        self.dsp = dsp
        self.config = config
        self.train_cfg = config['vocoder']['training']
        self.loss_func = F.cross_entropy if self.dsp.voc_mode == 'RAW' else discretized_mix_logistic_loss
        path_top_k = paths.voc_top_k / 'top_k.pkl'
        if os.path.exists(path_top_k):
            # Resume the ranking of a previous run and re-log its audio.
            self.top_k_models = unpickle_binary(path_top_k)
            self._log_top_k_audio()
        else:
            self.top_k_models = []

    def _log_top_k_audio(self) -> None:
        """Log the stored generated sample of every current top-k model."""
        for i, (mel_loss, g_wav, m_step, m_name) in enumerate(self.top_k_models, 1):
            self.writer.add_audio(tag=f'Top_K_Models/generated_top_{i}',
                                  snd_tensor=g_wav, global_step=m_step,
                                  sample_rate=self.dsp.sample_rate)

    def train(self, model: WaveRNN, optimizer: Optimizer, train_gta=False) -> None:
        """Run every schedule session whose ``max_step`` is not yet reached."""
        voc_schedule = parse_schedule(self.train_cfg['schedule'])
        for i, session_params in enumerate(voc_schedule, 1):
            lr, max_step, bs = session_params
            if model.get_step() < max_step:
                train_set, val_set, val_set_samples = get_vocoder_datasets(
                    path=self.paths.data, batch_size=bs, train_gta=train_gta,
                    max_mel_len=self.train_cfg['max_mel_len'],
                    hop_length=self.dsp.hop_length,
                    voc_pad=model.pad,
                    voc_seq_len=self.train_cfg['seq_len'],
                    voc_mode=self.dsp.voc_mode,
                    bits=self.dsp.bits,
                    num_gen_samples=self.train_cfg['num_gen_samples'])
                session = VocSession(index=i, lr=lr, max_step=max_step, bs=bs,
                                     train_set=train_set, val_set=val_set,
                                     val_set_samples=val_set_samples)
                self.train_session(model, optimizer, session, train_gta)

    def train_session(self, model: WaveRNN, optimizer: Optimizer,
                      session: VocSession, train_gta: bool) -> None:
        """Train for enough whole epochs to cover ``session.max_step``.

        Training runs in whole epochs, so it may run past ``max_step`` within
        the last epoch. Validation and the ``latest_model.pt`` checkpoint
        happen at the end of every epoch.
        """
        current_step = model.get_step()
        training_steps = session.max_step - current_step
        total_iters = len(session.train_set)
        epochs = training_steps // total_iters + 1
        simple_table([('Steps ', str(training_steps // 1000) + 'k'),
                      ('Batch Size', session.bs),
                      ('Learning Rate', session.lr),
                      ('Sequence Length', self.train_cfg['seq_len']),
                      ('GTA Training', train_gta)])

        for g in optimizer.param_groups:
            g['lr'] = session.lr

        loss_avg = Averager()
        duration_avg = Averager()
        device = next(model.parameters()).device  # use same device as model parameters

        for e in range(1, epochs + 1):
            for i, batch in enumerate(session.train_set, 1):
                start = time.time()
                model.train()
                batch = to_device(batch, device=device)
                x, y = batch['x'], batch['y']
                y_hat = model(x, batch['mel'])
                if model.mode == 'RAW':
                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
                elif model.mode == 'MOL':
                    y = batch['y'].float()
                y = y.unsqueeze(-1)

                loss = self.loss_func(y_hat, y)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               self.train_cfg['clip_grad_norm'])
                optimizer.step()

                # Detach once; used for the running average and TensorBoard.
                loss_value = loss.item()
                loss_avg.add(loss_value)
                step = model.get_step()
                k = step // 1000

                duration_avg.add(time.time() - start)
                speed = 1. / duration_avg.get()
                msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {loss_avg.get():#.4} ' \
                      f'| {speed:#.2} steps/s | Step: {k}k | '

                if step % self.train_cfg['gen_samples_every'] == 0:
                    stream(msg + 'generating samples...')
                    gen_result = self.generate_samples(model, session)
                    # None when generation failed (swallowed by @ignore_exception).
                    if gen_result is not None:
                        mel_loss, gen_wav = gen_result
                        self.writer.add_scalar('Loss/generated_mel_l1', mel_loss, model.get_step())
                        self.track_top_models(mel_loss, gen_wav, model)

                if step % self.train_cfg['checkpoint_every'] == 0:
                    save_checkpoint(model=model, optim=optimizer, config=self.config,
                                    path=self.paths.voc_checkpoints / f'wavernn_step{k}k.pt')

                self.writer.add_scalar('Loss/train', loss_value, model.get_step())
                self.writer.add_scalar('Params/batch_size', session.bs, model.get_step())
                self.writer.add_scalar('Params/learning_rate', session.lr, model.get_step())

                stream(msg)

            val_loss = self.evaluate(model, session.val_set)
            self.writer.add_scalar('Loss/val', val_loss, model.get_step())
            save_checkpoint(model=model, optim=optimizer, config=self.config,
                            path=self.paths.voc_checkpoints / 'latest_model.pt')

            loss_avg.reset()
            duration_avg.reset()
            print(' ')

    def evaluate(self, model: WaveRNN, val_set: Dataset) -> float:
        """Return the mean per-batch loss of ``model`` over ``val_set``."""
        model.eval()
        val_loss = 0
        device = next(model.parameters()).device
        for i, batch in enumerate(val_set, 1):
            batch = to_device(batch, device=device)
            x, y, m = batch['x'], batch['y'], batch['mel']
            with torch.no_grad():
                y_hat = model(x, m)
                if model.mode == 'RAW':
                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
                elif model.mode == 'MOL':
                    y = y.float()
                y = y.unsqueeze(-1)
                loss = self.loss_func(y_hat, y)
                val_loss += loss.item()
        return val_loss / len(val_set)

    @ignore_exception
    def generate_samples(self, model: WaveRNN, session: VocSession) -> Tuple[float, list]:
        """Generate audio for up to ``num_gen_samples`` validation samples.

        Audio quality is scored by the L1 distance between the mels of the
        generated and target waveforms; both waveforms are logged.

        Returns:
            ``(mean mel L1 loss, first generated waveform)``.
        """
        model.eval()
        mel_losses = []
        gen_wavs = []
        device = next(model.parameters()).device
        for i, sample in enumerate(session.val_set_samples, 1):
            m, x = sample['mel'], sample['x']
            if i > self.train_cfg['num_gen_samples']:
                break
            # Convert the target labels back to a float waveform.
            x = x[0].numpy()
            bits = 16 if self.dsp.voc_mode == 'MOL' else self.dsp.bits
            if self.dsp.mu_law and self.dsp.voc_mode != 'MOL':
                x = DSP.decode_mu_law(x, 2**bits, from_labels=True)
            else:
                x = DSP.label_2_float(x, bits)
            gen_wav = model.generate(mels=m, batched=self.train_cfg['gen_batched'],
                                     target=self.train_cfg['target'],
                                     overlap=self.train_cfg['overlap'],
                                     mu_law=self.dsp.mu_law, silent=True)
            gen_wavs.append(gen_wav)

            y_mel = self.dsp.wav_to_mel(x.squeeze(), normalize=False)
            y_mel = torch.tensor(y_mel).to(device)
            y_hat_mel = self.dsp.wav_to_mel(gen_wav, normalize=False)
            y_hat_mel = torch.tensor(y_hat_mel).to(device)
            loss = F.l1_loss(y_hat_mel, y_mel)
            mel_losses.append(loss.item())

            self.writer.add_audio(tag=f'Validation_Samples/target_{i}', snd_tensor=x,
                                  global_step=model.step, sample_rate=self.dsp.sample_rate)
            self.writer.add_audio(tag=f'Validation_Samples/generated_{i}', snd_tensor=gen_wav,
                                  global_step=model.step, sample_rate=self.dsp.sample_rate)

        return sum(mel_losses) / len(mel_losses), gen_wavs[0]

    def track_top_models(self, mel_loss, gen_wav, model):
        """Keep the ``keep_top_k`` best models (lowest mel loss) on disk.

        If ``mel_loss`` qualifies, the model is saved, the ranking is updated
        and persisted, and weight files that dropped out of the ranking are
        deleted.
        """
        keep_top_k = self.train_cfg['keep_top_k']
        if len(self.top_k_models) < keep_top_k or mel_loss < self.top_k_models[-1][0]:
            m_step = model.get_step()
            model_name = f'model_loss{mel_loss:#0.5}_step{m_step}_weights.pyt'
            self.top_k_models.append((mel_loss, gen_wav, model.get_step(), model_name))
            self.top_k_models.sort(key=lambda t: t[0])
            self.top_k_models = self.top_k_models[:keep_top_k]
            model.save(self.paths.voc_top_k / model_name)

            # Delete weight files that are no longer in the ranking.
            all_models = get_files(self.paths.voc_top_k, extension='pyt')
            top_k_names = {m[-1] for m in self.top_k_models}
            for model_file in all_models:
                if model_file.name not in top_k_names:
                    print(f'removing {model_file}')
                    os.remove(model_file)
            pickle_binary(self.top_k_models, self.paths.voc_top_k / 'top_k.pkl')

            self._log_top_k_audio()
def train(rank, a, h):
    """GAN vocoder training loop (multi-period + multi-scale discriminators).

    One call runs on one process/GPU (``rank``); with ``h.num_gpus > 1`` the
    models are wrapped in ``DistributedDataParallel``. Only rank 0 logs,
    validates and writes checkpoints.

    Resuming:

    * The newest ``g_*`` checkpoint in ``a.checkpoint_path`` is loaded
      leniently: only entries whose name and shape match the current
      generator are used.
    * Discriminators, optimizers, step and epoch are resumed from the newest
      ``do_*`` checkpoint only when a generator checkpoint exists and loaded
      completely, and ``a.from_zero`` is false.
    """
    if h.num_gpus > 1:
        init_process_group(backend=h.dist_config['dist_backend'],
                           init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus,
                           rank=rank)

    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = Generator(h).to(device)
    mpd = MultiPeriodDiscriminator(
        h["discriminator_periods"] if "discriminator_periods" in h.keys() else None).to(device)
    msd = MultiScaleDiscriminator().to(device)

    if rank == 0:
        print(generator)
        os.makedirs(a.checkpoint_path, exist_ok=True)
        print("checkpoints directory : ", a.checkpoint_path)

    # Must be bound even when the directory doesn't exist (e.g. on rank != 0,
    # which does not create it); previously this raised NameError.
    cp_g = cp_do = None
    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    missing_keys = set()
    if cp_g is not None:
        state_dict_g = load_checkpoint(cp_g, device)
        saved_g = state_dict_g['generator']
        gsd = generator.state_dict()
        # Keep only checkpoint entries that fit the current architecture.
        compatible = {k: v for k, v in saved_g.items()
                      if k in gsd and v.shape == gsd[k].shape}
        missing_keys = set(saved_g) - set(compatible)
        gsd.update(compatible)
        generator.load_state_dict(gsd)
        del gsd, state_dict_g, saved_g, compatible

    # Only resume D/optimizers/step when G was fully restored; otherwise start
    # fresh. ``cp_g is None`` used to fall through to an unbound
    # ``missing_keys`` (NameError) whenever a ``do_`` checkpoint existed.
    if cp_g is None or cp_do is None or missing_keys or a.from_zero:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_do = load_checkpoint(cp_do, device)
        mpd.load_state_dict(state_dict_do['mpd'])
        del state_dict_do['mpd']
        msd.load_state_dict(state_dict_do['msd'])
        del state_dict_do['msd']
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if h.num_gpus > 1:
        generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
        mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate,
                                betas=[h.adam_b1, h.adam_b2])
    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
                                h.learning_rate, betas=[h.adam_b1, h.adam_b2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])
        del state_dict_do

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay,
                                                         last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay,
                                                         last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(
        a, h.segment_size, h.sampling_rate)

    trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
                          h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax,
                          n_cache_reuse=0, shuffle=False if h.num_gpus > 1 else True,
                          fmax_loss=h.fmax_for_loss, device=device,
                          fine_tuning=a.fine_tuning, trim_non_voiced=a.trim_non_voiced)

    STFT = STFT_Class(h.sampling_rate, h.num_mels, h.n_fft, h.win_size, h.hop_size,
                      h.fmin, h.fmax)

    train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None

    train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
                              sampler=train_sampler, batch_size=h.batch_size,
                              pin_memory=True, drop_last=True)
    assert len(train_loader), 'No audio files in dataset!'

    if rank == 0:
        validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
                              h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax,
                              False, False, n_cache_reuse=0,
                              fmax_loss=h.fmax_for_loss, device=device,
                              fine_tuning=a.fine_tuning, trim_non_voiced=a.trim_non_voiced)
        validation_loader = DataLoader(validset, num_workers=h.num_workers, shuffle=False,
                                       sampler=None, batch_size=1, pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'), max_queue=10000)
        # Ground-truth audio/spectrograms are logged only at the first validation.
        sw.logged_gt_plots = False

    if h.num_gpus > 1:
        import gc
        gc.collect()
        torch.cuda.empty_cache()

    generator.train()
    mpd.train()
    msd.train()
    for epoch in range(max(0, last_epoch), a.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if h.num_gpus > 1:
            train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            if rank == 0:
                start_b = time.time()
            x, y, _, y_mel = batch
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            y_mel = y_mel.to(device, non_blocking=True)
            y = y.unsqueeze(1)

            y_g_hat = generator(x)
            y_g_hat_mel = STFT.get_mel(y_g_hat.squeeze(1))

            # ---------------- Discriminators
            optim_d.zero_grad()

            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)

            # MSD
            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)

            loss_disc_all = loss_disc_s + loss_disc_f
            loss_disc_all.backward()
            optim_d.step()

            # ---------------- Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel)

            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel * 45

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                # Logging / checkpointing / validation need no autograd.
                torch.set_grad_enabled(False)

                # STDOUT logging
                if steps % a.stdout_interval == 0:
                    print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'
                          .format(steps, loss_gen_all.item(), loss_mel.item(), time.time() - start_b))

                # checkpointing
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path, {
                        'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()
                    })
                    checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path, {
                        'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
                        'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
                        'optim_g': optim_g.state_dict(),
                        'optim_d': optim_d.state_dict(),
                        'steps': steps,
                        'epoch': epoch,
                    })
                    del_old_checkpoints(a.checkpoint_path, 'g_', a.n_models_to_keep)
                    del_old_checkpoints(a.checkpoint_path, 'do_', a.n_models_to_keep)

                # Tensorboard summary logging
                if steps % a.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
                    sw.add_scalar("training/mel_spec_error", loss_mel.item(), steps)

                # Validation
                if steps % a.validation_interval == 0:
                    print("Validating...")
                    n_audios_to_plot = 6
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    n_val = 0
                    for j, batch in tqdm(enumerate(validation_loader), total=len(validation_loader)):
                        x, y, _, y_mel = batch
                        y_g_hat = generator(x.to(device))
                        y_hat_spec = STFT.get_mel(y_g_hat.squeeze(1))
                        val_err_tot += F.l1_loss(y_mel, y_hat_spec.to(y_mel)).item()
                        n_val += 1

                        if j < n_audios_to_plot and not sw.logged_gt_plots:
                            sw.add_audio(f'gt/y_{j}', y[0], steps, h.sampling_rate)
                            sw.add_figure(f'spec_{j:02}/gt_spec', plot_spectrogram(y_mel[0]), steps)
                        if j < n_audios_to_plot:
                            sw.add_audio(f'generated/y_hat_{j}', y_g_hat[0], steps, h.sampling_rate)
                            sw.add_figure(f'spec_{j:02}/pred_spec',
                                          plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()),
                                          steps)

                        # Validation is capped at the first 66 batches to keep it short.
                        if j > 64:
                            break
                    sw.logged_gt_plots = True
                    # Guard against an empty validation loader (``j`` unbound).
                    val_err = val_err_tot / n_val if n_val else float('nan')
                    sw.add_scalar("validation/mel_spec_error", val_err, steps)
                    generator.train()
                    print(f"Done. Val_loss = {val_err}")

                torch.set_grad_enabled(True)

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
def main(args):
    """Train Wave-U-Net with early stopping, then evaluate the best checkpoint.

    Builds the model and the MUSDB train/val/test datasets from ``args``, then
    trains one epoch at a time until the validation loss has failed to improve
    for ``args.patience`` consecutive epochs. A checkpoint is saved after every
    epoch. Afterwards the best checkpoint (by validation loss) is reloaded and
    its test loss and mean SDR/SIR per instrument are logged to TensorBoard.
    The raw mir_eval metrics are pickled to ``<checkpoint_dir>/results.pkl``.

    Args:
        args: parsed command-line namespace. Attributes read here include
            features, levels, feature_growth, output_size, sr, channels,
            instruments, kernel_size, depth, strides, conv_type, res,
            separate, cuda, log_dir, dataset_dir, hdf_dir, batch_size,
            num_workers, loss, lr, min_lr, cycles, load_model, patience,
            example_freq and checkpoint_dir.

    Raises:
        NotImplementedError: if ``args.loss`` is neither "L1" nor "L2".
    """
    # NOTE: torch.backends.cudnn.benchmark = True reportedly makes dilated
    # convolutions much faster with CuDNN 7.5; it is intentionally left off.

    # ---- MODEL ----
    # Per-level feature counts grow linearly ("add") or by doubling otherwise.
    if args.feature_growth == "add":
        num_features = [args.features * i for i in range(1, args.levels + 1)]
    else:
        num_features = [args.features * 2 ** i for i in range(0, args.levels)]
    target_outputs = int(args.output_size * args.sr)
    model = Waveunet(args.channels, num_features, args.channels, args.instruments,
                     kernel_size=args.kernel_size, target_output_size=target_outputs,
                     depth=args.depth, strides=args.strides, conv_type=args.conv_type,
                     res=args.res, separate=args.separate)

    if args.cuda:
        model = utils.DataParallel(model)
        print("move model to gpu")
        model.cuda()

    print('model: ', model)
    print('parameter count: ', str(sum(p.numel() for p in model.parameters())))

    writer = SummaryWriter(args.log_dir)

    # ---- DATASET ----
    musdb = get_musdb_folds(args.dataset_dir)
    # If not data augmentation, at least crop targets to fit model output shape.
    crop_func = partial(crop, shapes=model.shapes)
    # Data augmentation function for training (random_amplify, min=0.7, max=1.0).
    augment_func = partial(random_amplify, shapes=model.shapes, min=0.7, max=1.0)
    train_data = SeparationDataset(musdb, "train", args.instruments, args.sr, args.channels,
                                   model.shapes, True, args.hdf_dir, audio_transform=augment_func)
    val_data = SeparationDataset(musdb, "val", args.instruments, args.sr, args.channels,
                                 model.shapes, False, args.hdf_dir, audio_transform=crop_func)
    test_data = SeparationDataset(musdb, "test", args.instruments, args.sr, args.channels,
                                  model.shapes, False, args.hdf_dir, audio_transform=crop_func)

    dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                             num_workers=args.num_workers,
                                             worker_init_fn=utils.worker_init_fn)

    # ---- TRAINING ----
    if args.loss == "L1":
        criterion = nn.L1Loss()
    elif args.loss == "L2":
        criterion = nn.MSELoss()
    else:
        raise NotImplementedError("Couldn't find this loss!")

    optimizer = Adam(params=model.parameters(), lr=args.lr)

    # Training state that is also saved into every checkpoint.
    # np.inf rather than the np.Inf alias, which NumPy 2.0 removed.
    state = {"step": 0,
             "worse_epochs": 0,
             "epochs": 0,
             "best_loss": np.inf}

    # LOAD MODEL CHECKPOINT IF DESIRED
    if args.load_model is not None:
        print("Continuing training full model from checkpoint " + str(args.load_model))
        state = utils.load_model(model, optimizer, args.load_model)

    print('TRAINING START')
    iters_per_epoch = len(train_data) // args.batch_size  # loop-invariant
    while state["worse_epochs"] < args.patience:
        print("Training one epoch from iteration " + str(state["step"]))
        model.train()
        with tqdm(total=iters_per_epoch) as pbar:
            np.random.seed()
            for example_num, (x, targets) in enumerate(dataloader):
                if args.cuda:
                    x = x.cuda()
                    for k in list(targets.keys()):
                        targets[k] = targets[k].cuda()

                # Cyclic learning-rate schedule within the epoch.
                utils.set_cyclic_lr(optimizer, example_num, iters_per_epoch,
                                    args.cycles, args.min_lr, args.lr)
                writer.add_scalar("lr", utils.get_lr(optimizer), state["step"])

                # Compute loss for each instrument/model. With compute_grad=True
                # compute_loss presumably also back-propagates, since
                # optimizer.step() follows directly (verify in utils).
                optimizer.zero_grad()
                outputs, avg_loss = utils.compute_loss(model, x, targets, criterion,
                                                       compute_grad=True)
                optimizer.step()

                state["step"] += 1

                writer.add_scalar("train_loss", avg_loss, state["step"])

                if example_num % args.example_freq == 0:
                    # Stereo not supported for logs yet: average over channels.
                    input_centre = torch.mean(
                        x[0, :, model.shapes["output_start_frame"]:model.shapes["output_end_frame"]], 0)
                    writer.add_audio("input", input_centre, state["step"], sample_rate=args.sr)

                    for inst in outputs.keys():
                        writer.add_audio(inst + "_pred", torch.mean(outputs[inst][0], 0),
                                         state["step"], sample_rate=args.sr)
                        writer.add_audio(inst + "_target", torch.mean(targets[inst][0], 0),
                                         state["step"], sample_rate=args.sr)

                pbar.update(1)

        # VALIDATE
        val_loss = validate(args, model, criterion, val_data)
        print("VALIDATION FINISHED: LOSS: " + str(val_loss))
        writer.add_scalar("val_loss", val_loss, state["step"])

        # EARLY STOPPING CHECK
        checkpoint_path = os.path.join(args.checkpoint_dir, "checkpoint_" + str(state["step"]))
        if val_loss >= state["best_loss"]:
            state["worse_epochs"] += 1
        else:
            print("MODEL IMPROVED ON VALIDATION SET!")
            state["worse_epochs"] = 0
            state["best_loss"] = val_loss
            state["best_checkpoint"] = checkpoint_path

        # CHECKPOINT
        print("Saving model...")
        utils.save_model(model, optimizer, state, checkpoint_path)

        state["epochs"] += 1

    # ---- TESTING ----
    print("TESTING")

    # Load best model based on validation loss. If no epoch ran or improved,
    # there is no recorded best checkpoint, so test the current weights
    # instead of failing with a KeyError.
    best_checkpoint = state.get("best_checkpoint")
    if best_checkpoint is not None:
        state = utils.load_model(model, None, best_checkpoint)
    else:
        print("No best checkpoint recorded; testing the current model weights.")

    test_loss = validate(args, model, criterion, test_data)
    print("TEST FINISHED: LOSS: " + str(test_loss))
    writer.add_scalar("test_loss", test_loss, state["step"])

    # Mir_eval metrics
    test_metrics = evaluate(args, musdb["test"], model, args.instruments)

    # Dump all metrics results into pickle file for later analysis if needed.
    with open(os.path.join(args.checkpoint_dir, "results.pkl"), "wb") as f:
        pickle.dump(test_metrics, f)

    # Write most important metrics into Tensorboard log.
    avg_SDRs = {inst: np.mean([np.nanmean(song[inst]["SDR"]) for song in test_metrics])
                for inst in args.instruments}
    avg_SIRs = {inst: np.mean([np.nanmean(song[inst]["SIR"]) for song in test_metrics])
                for inst in args.instruments}
    for inst in args.instruments:
        writer.add_scalar("test_SDR_" + inst, avg_SDRs[inst], state["step"])
        writer.add_scalar("test_SIR_" + inst, avg_SIRs[inst], state["step"])
    overall_SDR = np.mean(list(avg_SDRs.values()))
    # Logged at the same step as the other test scalars (previously no step).
    writer.add_scalar("test_SDR", overall_SDR, state["step"])
    print("SDR: " + str(overall_SDR))

    writer.close()
class TacoTrainer:
    """Trainer for a Tacotron model following the session schedule ``hp.tts_schedule``.

    Training/validation losses, attention scores and hyper-parameters are
    written to a TensorBoard ``SummaryWriter`` at ``paths.tts_log``;
    checkpoints are written through ``save_checkpoint('tts', ...)``.
    """

    def __init__(self, paths: Paths) -> None:
        self.paths = paths
        self.writer = SummaryWriter(log_dir=paths.tts_log, comment='v1')

    def train(self, model: Tacotron, optimizer: Optimizer) -> None:
        """Run every schedule session whose ``max_step`` the model has not reached yet.

        Each entry of ``hp.tts_schedule`` is unpacked as ``(r, lr, max_step, bs)``.
        A fresh train/validation dataset pair is built for each session that runs.
        """
        for i, session_params in enumerate(hp.tts_schedule, 1):
            r, lr, max_step, bs = session_params
            if model.get_step() < max_step:
                train_set, val_set = get_tts_datasets(
                    path=self.paths.data, batch_size=bs, r=r, model_type='tacotron')
                session = TTSSession(
                    index=i, r=r, lr=lr, max_step=max_step, bs=bs,
                    train_set=train_set, val_set=val_set)
                self.train_session(model, optimizer, session)

    def train_session(self, model: Tacotron,
                      optimizer: Optimizer, session: TTSSession) -> None:
        """Train for enough whole epochs to cover the steps left until ``session.max_step``.

        Sets ``model.r`` and the optimizer learning rate from ``session``. Each
        iteration minimises the sum of the L1 losses of both mel outputs
        against the target and clips gradients to ``hp.tts_clip_grad_norm``.
        Every epoch ends with a validation pass and a checkpoint.
        """
        current_step = model.get_step()
        training_steps = session.max_step - current_step
        total_iters = len(session.train_set)
        epochs = training_steps // total_iters + 1
        model.r = session.r

        simple_table([(f'Steps with r={session.r}', str(training_steps // 1000) + 'k Steps'),
                      ('Batch Size', session.bs),
                      ('Learning Rate', session.lr),
                      ('Outputs/Step (r)', model.r)])

        for g in optimizer.param_groups:
            g['lr'] = session.lr

        loss_avg = Averager()
        duration_avg = Averager()
        device = next(model.parameters()).device  # use same device as model parameters
        for e in range(1, epochs + 1):
            for i, (x, m, ids, x_lens, mel_lens) in enumerate(session.train_set, 1):
                start = time.time()
                model.train()
                x, m = x.to(device), m.to(device)

                m1_hat, m2_hat, attention = model(x, m)

                m1_loss = F.l1_loss(m1_hat, m)
                m2_loss = F.l1_loss(m2_hat, m)
                loss = m1_loss + m2_loss

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
                optimizer.step()
                loss_avg.add(loss.item())
                step = model.get_step()
                k = step // 1000

                duration_avg.add(time.time() - start)
                speed = 1. / duration_avg.get()
                msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {loss_avg.get():#.4} ' \
                      f'| {speed:#.2} steps/s | Step: {k}k | '

                if step % hp.tts_checkpoint_every == 0:
                    ckpt_name = f'taco_step{k}K'
                    save_checkpoint('tts', self.paths, model, optimizer,
                                    name=ckpt_name, is_silent=True)

                if step % hp.tts_plot_every == 0:
                    self.generate_plots(model, session)

                _, att_score = attention_score(attention, mel_lens)
                att_score = torch.mean(att_score)
                self.writer.add_scalar('Attention_Score/train', att_score, model.get_step())
                self.writer.add_scalar('Loss/train', loss, model.get_step())
                self.writer.add_scalar('Params/reduction_factor', session.r, model.get_step())
                self.writer.add_scalar('Params/batch_size', session.bs, model.get_step())
                self.writer.add_scalar('Params/learning_rate', session.lr, model.get_step())

                stream(msg)

            val_loss, val_att_score = self.evaluate(model, session.val_set)
            self.writer.add_scalar('Loss/val', val_loss, model.get_step())
            self.writer.add_scalar('Attention_Score/val', val_att_score, model.get_step())
            save_checkpoint('tts', self.paths, model, optimizer, is_silent=True)

            loss_avg.reset()
            duration_avg.reset()
            print(' ')

    def evaluate(self, model: Tacotron, val_set: Dataset) -> Tuple[float, float]:
        """Return ``(mean mel L1 loss, mean attention score)`` over ``val_set``.

        The mel loss per batch is the sum of both outputs' L1 losses; both
        results are averaged over the number of batches. Leaves the model in
        eval mode.
        """
        model.eval()
        val_loss = 0
        val_att_score = 0
        device = next(model.parameters()).device
        for i, (x, m, ids, x_lens, mel_lens) in enumerate(val_set, 1):
            x, m = x.to(device), m.to(device)
            with torch.no_grad():
                m1_hat, m2_hat, attention = model(x, m)
                m1_loss = F.l1_loss(m1_hat, m)
                m2_loss = F.l1_loss(m2_hat, m)
                val_loss += m1_loss.item() + m2_loss.item()
            _, att_score = attention_score(attention, mel_lens)
            val_att_score += torch.mean(att_score).item()

        return val_loss / len(val_set), val_att_score / len(val_set)

    @ignore_exception
    def generate_plots(self, model: Tacotron, session: TTSSession) -> None:
        """Log teacher-forced and free-running plots/audio for ``session.val_sample``.

        Exceptions are handled by the externally defined ``ignore_exception``
        decorator (presumably so plotting cannot abort training — verify).
        Mel plots are truncated to the first 600 frames.
        """
        model.eval()
        device = next(model.parameters()).device
        x, m, ids, x_lens, m_lens = session.val_sample
        x, m = x.to(device), m.to(device)

        # Plotting is inference only: skip building an autograd graph.
        with torch.no_grad():
            m1_hat, m2_hat, att = model(x, m)
        att = np_now(att)[0]
        m1_hat = np_now(m1_hat)[0, :600, :]
        m2_hat = np_now(m2_hat)[0, :600, :]
        m = np_now(m)[0, :600, :]

        att_fig = plot_attention(att)
        m1_hat_fig = plot_mel(m1_hat)
        m2_hat_fig = plot_mel(m2_hat)
        m_fig = plot_mel(m)

        self.writer.add_figure('Ground_Truth_Aligned/attention', att_fig, model.step)
        self.writer.add_figure('Ground_Truth_Aligned/target', m_fig, model.step)
        self.writer.add_figure('Ground_Truth_Aligned/linear', m1_hat_fig, model.step)
        self.writer.add_figure('Ground_Truth_Aligned/postnet', m2_hat_fig, model.step)

        m2_hat_wav = reconstruct_waveform(m2_hat)
        target_wav = reconstruct_waveform(m)

        self.writer.add_audio(
            tag='Ground_Truth_Aligned/target_wav', snd_tensor=target_wav,
            global_step=model.step, sample_rate=hp.sample_rate)
        self.writer.add_audio(
            tag='Ground_Truth_Aligned/postnet_wav', snd_tensor=m2_hat_wav,
            global_step=model.step, sample_rate=hp.sample_rate)

        # Free-running generation, allowed 20 steps beyond the target length.
        with torch.no_grad():
            m1_hat, m2_hat, att = model.generate(x[0].tolist(), steps=m_lens[0] + 20)
        att_fig = plot_attention(att)
        m1_hat_fig = plot_mel(m1_hat)
        m2_hat_fig = plot_mel(m2_hat)

        self.writer.add_figure('Generated/attention', att_fig, model.step)
        self.writer.add_figure('Generated/target', m_fig, model.step)
        self.writer.add_figure('Generated/linear', m1_hat_fig, model.step)
        self.writer.add_figure('Generated/postnet', m2_hat_fig, model.step)

        m2_hat_wav = reconstruct_waveform(m2_hat)

        self.writer.add_audio(
            tag='Generated/target_wav', snd_tensor=target_wav,
            global_step=model.step, sample_rate=hp.sample_rate)
        self.writer.add_audio(
            tag='Generated/postnet_wav', snd_tensor=m2_hat_wav,
            global_step=model.step, sample_rate=hp.sample_rate)
class ForwardTrainer:
    """Trainer for a ForwardTacotron model following ``hp.forward_schedule``.

    Uses a masked L1 loss (``MaskedL1``) for mel, duration and pitch targets.
    Losses and hyper-parameters are written to a TensorBoard ``SummaryWriter``
    at ``paths.forward_log``; checkpoints go through
    ``save_checkpoint('forward', ...)``.
    """

    def __init__(self, paths: Paths) -> None:
        self.paths = paths
        self.writer = SummaryWriter(log_dir=paths.forward_log, comment='v1')
        self.l1_loss = MaskedL1()

    def train(self, model: ForwardTacotron, optimizer: Optimizer) -> None:
        """Run every schedule session whose ``max_step`` the model has not reached yet.

        Each entry of ``hp.forward_schedule`` is unpacked as ``(lr, max_step, bs)``;
        the reduction factor is always 1 for this model.
        """
        for i, session_params in enumerate(hp.forward_schedule, 1):
            lr, max_step, bs = session_params
            if model.get_step() < max_step:
                train_set, val_set = get_tts_datasets(
                    path=self.paths.data, batch_size=bs, r=1, model_type='forward')
                session = TTSSession(
                    index=i, r=1, lr=lr, max_step=max_step,
                    bs=bs, train_set=train_set, val_set=val_set)
                self.train_session(model, optimizer, session)

    def train_session(self, model: ForwardTacotron,
                      optimizer: Optimizer, session: TTSSession) -> None:
        """Train for enough whole epochs to cover the steps left until ``session.max_step``.

        The optimised loss is ``m1 + m2 + 0.3 * duration + 0.1 * pitch`` (all
        masked L1), with gradients clipped to ``hp.tts_clip_grad_norm``. Every
        epoch ends with a validation pass, a checkpoint and a reset of the
        running averages shown in the progress line.
        """
        current_step = model.get_step()
        training_steps = session.max_step - current_step
        total_iters = len(session.train_set)
        epochs = training_steps // total_iters + 1
        simple_table([('Steps', str(training_steps // 1000) + 'k Steps'),
                      ('Batch Size', session.bs),
                      ('Learning Rate', session.lr)])

        for g in optimizer.param_groups:
            g['lr'] = session.lr

        m_loss_avg = Averager()
        dur_loss_avg = Averager()
        duration_avg = Averager()
        pitch_loss_avg = Averager()
        device = next(model.parameters()).device  # use same device as model parameters
        for e in range(1, epochs + 1):
            for i, (x, m, ids, x_lens, mel_lens, dur, pitch, puncts) in enumerate(
                session.train_set, 1
            ):
                start = time.time()
                model.train()
                x, m, dur, x_lens, mel_lens, pitch, puncts = (
                    x.to(device),
                    m.to(device),
                    dur.to(device),
                    x_lens.to(device),
                    mel_lens.to(device),
                    pitch.to(device),
                    puncts.to(device),
                )

                m1_hat, m2_hat, dur_hat, pitch_hat = model(
                    x, m, dur, mel_lens, pitch, puncts
                )

                m1_loss = self.l1_loss(m1_hat, m, mel_lens)
                m2_loss = self.l1_loss(m2_hat, m, mel_lens)

                dur_loss = self.l1_loss(dur_hat.unsqueeze(1), dur.unsqueeze(1), x_lens)
                pitch_loss = self.l1_loss(pitch_hat, pitch.unsqueeze(1), x_lens)

                loss = m1_loss + m2_loss + 0.3 * dur_loss + 0.1 * pitch_loss
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
                optimizer.step()
                m_loss_avg.add(m1_loss.item() + m2_loss.item())
                dur_loss_avg.add(dur_loss.item())
                step = model.get_step()
                k = step // 1000

                duration_avg.add(time.time() - start)
                pitch_loss_avg.add(pitch_loss.item())

                speed = 1. / duration_avg.get()
                msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Mel Loss: {m_loss_avg.get():#.4} ' \
                      f'| Dur Loss: {dur_loss_avg.get():#.4} | Pitch Loss: {pitch_loss_avg.get():#.4} ' \
                      f'| {speed:#.2} steps/s | Step: {k}k | '

                if step % hp.forward_checkpoint_every == 0:
                    ckpt_name = f'forward_step{k}K'
                    save_checkpoint('forward', self.paths, model, optimizer,
                                    name=ckpt_name, is_silent=True)

                if step % hp.forward_plot_every == 0:
                    self.generate_plots(model, session)

                self.writer.add_scalar('Mel_Loss/train', m1_loss + m2_loss, model.get_step())
                self.writer.add_scalar('Pitch_Loss/train', pitch_loss, model.get_step())
                self.writer.add_scalar('Duration_Loss/train', dur_loss, model.get_step())
                self.writer.add_scalar('Params/batch_size', session.bs, model.get_step())
                self.writer.add_scalar('Params/learning_rate', session.lr, model.get_step())

                stream(msg)

            m_val_loss, dur_val_loss, pitch_val_loss = self.evaluate(model, session.val_set)
            self.writer.add_scalar('Mel_Loss/val', m_val_loss, model.get_step())
            self.writer.add_scalar('Duration_Loss/val', dur_val_loss, model.get_step())
            self.writer.add_scalar('Pitch_Loss/val', pitch_val_loss, model.get_step())
            save_checkpoint('forward', self.paths, model, optimizer, is_silent=True)

            # Reset every running average so each epoch's progress line is
            # per-epoch (dur_loss_avg was previously never reset).
            m_loss_avg.reset()
            dur_loss_avg.reset()
            duration_avg.reset()
            pitch_loss_avg.reset()
            print(' ')

    def evaluate(self, model: ForwardTacotron, val_set: Dataset) -> Tuple[float, float, float]:
        """Return mean ``(mel, duration, pitch)`` masked-L1 losses over ``val_set``.

        The mel loss per batch is the sum of both mel outputs' losses. All
        three values are Python floats averaged over the number of batches.
        Leaves the model in eval mode.
        """
        model.eval()
        m_val_loss = 0
        dur_val_loss = 0
        pitch_val_loss = 0
        device = next(model.parameters()).device
        for i, (x, m, ids, x_lens, mel_lens, dur, pitch, puncts) in enumerate(
            val_set, 1
        ):
            x, m, dur, x_lens, mel_lens, pitch, puncts = (
                x.to(device),
                m.to(device),
                dur.to(device),
                x_lens.to(device),
                mel_lens.to(device),
                pitch.to(device),
                puncts.to(device),
            )
            with torch.no_grad():
                m1_hat, m2_hat, dur_hat, pitch_hat = model(
                    x, m, dur, mel_lens, pitch, puncts
                )
                m1_loss = self.l1_loss(m1_hat, m, mel_lens)
                m2_loss = self.l1_loss(m2_hat, m, mel_lens)
                dur_loss = self.l1_loss(dur_hat.unsqueeze(1), dur.unsqueeze(1), x_lens)
                # .item() so the pitch loss is a float like the other two
                # (previously it accumulated and returned a tensor).
                pitch_val_loss += self.l1_loss(pitch_hat, pitch.unsqueeze(1), x_lens).item()
                m_val_loss += m1_loss.item() + m2_loss.item()
                dur_val_loss += dur_loss.item()
        m_val_loss /= len(val_set)
        dur_val_loss /= len(val_set)
        pitch_val_loss /= len(val_set)
        return m_val_loss, dur_val_loss, pitch_val_loss

    @ignore_exception
    def generate_plots(self, model: ForwardTacotron, session: TTSSession) -> None:
        """Log ground-truth-aligned and generated mel/pitch plots and audio for ``session.val_sample``.

        Exceptions are handled by the externally defined ``ignore_exception``
        decorator (presumably so plotting cannot abort training — verify).
        Mel plots are truncated to the first 600 frames.
        """
        model.eval()
        device = next(model.parameters()).device
        x, m, ids, x_lens, mel_lens, dur, pitch, puncts = session.val_sample
        x, m, dur, mel_lens, pitch, puncts = (
            x.to(device),
            m.to(device),
            dur.to(device),
            mel_lens.to(device),
            pitch.to(device),
            puncts.to(device),
        )

        # Plotting is inference only: skip building an autograd graph.
        with torch.no_grad():
            m1_hat, m2_hat, dur_hat, pitch_hat = model(x, m, dur, mel_lens, pitch, puncts)
        m1_hat = np_now(m1_hat)[0, :600, :]
        m2_hat = np_now(m2_hat)[0, :600, :]
        m = np_now(m)[0, :600, :]

        m1_hat_fig = plot_mel(m1_hat)
        m2_hat_fig = plot_mel(m2_hat)
        m_fig = plot_mel(m)
        pitch_fig = plot_pitch(np_now(pitch[0]))
        pitch_gta_fig = plot_pitch(np_now(pitch_hat.squeeze()[0]))

        self.writer.add_figure('Pitch/target', pitch_fig, model.step)
        self.writer.add_figure('Pitch/ground_truth_aligned', pitch_gta_fig, model.step)
        self.writer.add_figure('Ground_Truth_Aligned/target', m_fig, model.step)
        self.writer.add_figure('Ground_Truth_Aligned/linear', m1_hat_fig, model.step)
        self.writer.add_figure('Ground_Truth_Aligned/postnet', m2_hat_fig, model.step)

        m2_hat_wav = reconstruct_waveform(m2_hat)
        target_wav = reconstruct_waveform(m)

        self.writer.add_audio(
            tag='Ground_Truth_Aligned/target_wav', snd_tensor=target_wav,
            global_step=model.step, sample_rate=hp.sample_rate)
        self.writer.add_audio(
            tag='Ground_Truth_Aligned/postnet_wav', snd_tensor=m2_hat_wav,
            global_step=model.step, sample_rate=hp.sample_rate)

        # Free-running generation from the unpadded text/punctuation of item 0.
        with torch.no_grad():
            m1_hat, m2_hat, dur_hat, pitch_hat = model.generate(
                x[0, : x_lens[0]].tolist(), puncts[0, : x_lens[0]].tolist()
            )

        m1_hat_fig = plot_mel(m1_hat)
        m2_hat_fig = plot_mel(m2_hat)

        pitch_gen_fig = plot_pitch(np_now(pitch_hat.squeeze()))

        self.writer.add_figure('Pitch/generated', pitch_gen_fig, model.step)
        self.writer.add_figure('Generated/target', m_fig, model.step)
        self.writer.add_figure('Generated/linear', m1_hat_fig, model.step)
        self.writer.add_figure('Generated/postnet', m2_hat_fig, model.step)

        m2_hat_wav = reconstruct_waveform(m2_hat)

        self.writer.add_audio(
            tag='Generated/target_wav', snd_tensor=target_wav,
            global_step=model.step, sample_rate=hp.sample_rate)
        self.writer.add_audio(
            tag='Generated/postnet_wav', snd_tensor=m2_hat_wav,
            global_step=model.step, sample_rate=hp.sample_rate)
def _peak_normalize(wav):
    """Scale ``wav`` so its largest absolute sample is 1 (silent input returned unchanged).

    TensorBoard audio expects samples in [-1, 1]; see
    https://github.com/lanpa/tensorboardX/issues/511#issuecomment-537600045.
    Dividing by ``max(wav)`` alone does not bound the negative peak, flips
    the sign of all-negative signals and divides by zero on silence, so the
    absolute peak is used instead. Accepts anything supporting ``abs()``,
    ``.max()`` and division (e.g. NumPy arrays, torch tensors).
    """
    peak = abs(wav).max()
    if peak > 0:
        return wav / peak
    return wav


def _log_styler_losses(writer, step, total, mel, mel_postnet, mel_noisy, mel_postnet_noisy,
                       duration, f0, energy, dat_clean, dat_noisy):
    """Write the ten STYLER loss scalars under the ``Loss/`` prefix at ``step``."""
    scalars = {
        'total_loss': total,
        'mel_loss': mel,
        'mel_postnet_loss': mel_postnet,
        'mel_noisy_loss': mel_noisy,
        'mel_postnet_noisy_loss': mel_postnet_noisy,
        'duration_loss': duration,
        'F0_loss': f0,
        'energy_loss': energy,
        'dat_clean_loss': dat_clean,
        'dat_noisy_loss': dat_noisy,
    }
    for tag, value in scalars.items():
        writer.add_scalar('Loss/' + tag, value, step)


def main(args):
    """Train the STYLER model.

    Tries to restore ``checkpoint_<args.restore_step>.pth.tar`` from
    ``hp.checkpoint_path()``; on failure starts fresh. Trains for ``hp.epochs``
    epochs with gradient accumulation over ``hp.acc_steps`` sub-batches. On
    optimizer-update steps it logs training losses (step 1 and every
    ``hp.log_step``), saves a checkpoint (every ``hp.save_step``), logs
    vocoded samples and plots (step 1 and every ``hp.synth_step``) and runs
    validation (step 1 and every ``hp.eval_step``).

    Args:
        args: namespace providing ``restore_step`` (the global step to resume from).
    """
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Each loader item holds hp.batch_size**2 samples; collate_fn presumably
    # groups them into hp.batch_size sub-batches (the current_step arithmetic
    # below assumes this — verify against Dataset.collate_fn).
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset, batch_size=hp.batch_size ** 2, shuffle=True,
                        collate_fn=dataset.collate_fn, drop_last=True, num_workers=0)

    # Define model
    model = nn.DataParallel(STYLER()).to(device)
    print("Model Has Been Defined")
    style_modeling = model.module.style_modeling

    # Parameters
    num_param = utils.get_param_num(model)
    text_encoder = utils.get_param_num(style_modeling.style_encoder.text_encoder)
    audio_encoder = utils.get_param_num(style_modeling.style_encoder.audio_encoder)
    predictors = (utils.get_param_num(style_modeling.duration_predictor)
                  + utils.get_param_num(style_modeling.pitch_predictor)
                  + utils.get_param_num(style_modeling.energy_predictor))
    decoder = utils.get_param_num(model.module.decoder)
    print('Number of Model Parameters :', num_param)
    print('Number of Text Encoder Parameters :', text_encoder)
    print('Number of Audio Encoder Parameters :', audio_encoder)
    print('Number of Predictor Parameters :', predictors)
    print('Number of Decoder Parameters :', decoder)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), betas=hp.betas, eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden, hp.n_warm_up_step,
                                     args.restore_step)
    Loss = STYLERLoss().to(device)
    DATLoss = DomainAdversarialTrainingLoss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists. This is best effort: any failure falls back to
    # fresh training, but (unlike a bare except) Ctrl-C/SystemExit propagate
    # and unexpected failures are reported instead of silently ignored.
    checkpoint_path = os.path.join(hp.checkpoint_path())
    try:
        checkpoint = torch.load(os.path.join(
            checkpoint_path, 'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except FileNotFoundError:
        print("\n---Start New Training---\n")
    except Exception as err:
        print("Could not restore checkpoint at step {} ({}: {})".format(
            args.restore_step, type(err).__name__, err))
        print("\n---Start New Training---\n")
    os.makedirs(checkpoint_path, exist_ok=True)

    # Load vocoder
    vocoder = utils.get_vocoder()

    # Init logger
    log_path = hp.log_path()
    train_log_path = os.path.join(log_path, 'train')
    val_log_path = os.path.join(log_path, 'validation')
    os.makedirs(train_log_path, exist_ok=True)
    os.makedirs(val_log_path, exist_ok=True)
    train_logger = SummaryWriter(train_log_path)
    val_logger = SummaryWriter(val_log_path)

    # Init synthesis directory
    synth_path = hp.synth_path()
    os.makedirs(synth_path, exist_ok=True)

    def to_device(data, key, dtype):
        """Convert the NumPy array ``data[key]`` to a ``dtype`` tensor on ``device``."""
        return torch.from_numpy(data[key]).to(dtype=dtype).to(device)

    # Per-step durations used for the remaining-time estimate.
    Time = np.array([])
    Start = time.perf_counter()
    total_step = hp.epochs * len(loader) * hp.batch_size  # loop-invariant

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(loader) * hp.batch_size + 1

                # Get Data
                text = to_device(data_of_batch, "text", torch.long)
                mel_target = to_device(data_of_batch, "mel_target", torch.float)
                mel_aug = to_device(data_of_batch, "mel_aug", torch.float)
                D = to_device(data_of_batch, "D", torch.long)
                log_D = to_device(data_of_batch, "log_D", torch.float)
                f0 = to_device(data_of_batch, "f0", torch.float)
                f0_norm = to_device(data_of_batch, "f0_norm", torch.float)
                f0_norm_aug = to_device(data_of_batch, "f0_norm_aug", torch.float)
                energy = to_device(data_of_batch, "energy", torch.float)
                energy_input = to_device(data_of_batch, "energy_input", torch.float)
                energy_input_aug = to_device(data_of_batch, "energy_input_aug", torch.float)
                speaker_embed = to_device(data_of_batch, "speaker_embed", torch.float)
                src_len = to_device(data_of_batch, "src_len", torch.long)
                mel_len = to_device(data_of_batch, "mel_len", torch.long)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                (mel_outputs, mel_postnet_outputs, log_duration_output, f0_output,
                 energy_output, src_mask, mel_mask, _, aug_posteriors) = model(
                    text, mel_target, mel_aug, f0_norm, energy_input, src_len, mel_len,
                    D, f0, energy, max_src_len, max_mel_len, speaker_embed=speaker_embed)

                # Clean-branch loss (output index 0); augmentation-classifier label 0.
                mel_output, mel_postnet_output = mel_outputs[0], mel_postnet_outputs[0]
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss, classifier_loss_a = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target, ~src_mask, ~mel_mask,
                    src_len, mel_len, aug_posteriors,
                    torch.zeros(mel_target.size(0)).long().to(device))

                # Noisy-branch mel loss (output index 1) against the augmented mel.
                mel_output_noisy, mel_postnet_output_noisy = mel_outputs[1], mel_postnet_outputs[1]
                mel_noisy_loss, mel_postnet_noisy_loss = Loss.cal_mel_loss(
                    mel_output_noisy, mel_postnet_output_noisy, mel_aug, ~mel_mask)

                # Domain-adversarial pass on the augmented inputs; classifier label 1.
                enc_cat = style_modeling.style_encoder.encoder_input_cat(
                    mel_aug, f0_norm_aug, energy_input_aug, mel_aug)
                duration_encoding, pitch_encoding, energy_encoding, _ = \
                    style_modeling.style_encoder.audio_encoder(enc_cat, mel_len, src_len, mask=None)
                aug_posterior_d = style_modeling.augmentation_classifier_d(duration_encoding)
                aug_posterior_p = style_modeling.augmentation_classifier_p(pitch_encoding)
                aug_posterior_e = style_modeling.augmentation_classifier_e(energy_encoding)

                classifier_loss_a_dat = DATLoss(
                    (aug_posterior_d, aug_posterior_p, aug_posterior_e),
                    torch.ones(mel_target.size(0)).long().to(device))

                # Total loss
                total_loss = (mel_loss + mel_postnet_loss + mel_noisy_loss + mel_postnet_noisy_loss
                              + d_loss + f_loss + e_loss
                              + hp.dat_weight * (classifier_loss_a + classifier_loss_a_dat))

                # Scalars for logging (taken before the accumulation rescale).
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                m_n_l = mel_noisy_loss.item()
                m_p_n_l = mel_postnet_noisy_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                cl_a = classifier_loss_a.item()
                cl_a_dat = classifier_loss_a_dat.item()

                # Backward; gradients accumulate over hp.acc_steps sub-batches.
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step == 1 or current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    # No step has been timed yet on the very first log; report
                    # nan (as before) without NumPy's empty-mean warning.
                    mean_step_time = np.mean(Time) if Time.size else float('nan')

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start), (total_step - current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    _log_styler_losses(train_logger, current_step, t_l, m_l, m_p_l, m_n_l,
                                       m_p_n_l, d_l, f_l, e_l, cl_a, cl_a_dat)

                if current_step % hp.save_step == 0:
                    torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()},
                               os.path.join(checkpoint_path,
                                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step == 1 or current_step % hp.synth_step == 0:
                    length = mel_len[0].item()

                    # Vocoder inputs: item 0, first `length` frames, transposed
                    # with transpose(1, 2) — presumably (1, n_mels, frames); TODO confirm.
                    mel_target_torch = mel_target[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_aug_torch = mel_aug[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_noisy_torch = mel_output_noisy[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet_torch = mel_postnet_output[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet_noisy_torch = mel_postnet_output_noisy[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)

                    # Plot inputs: the same slices on CPU, transposed with transpose(0, 1).
                    # New names so the batch tensors are not shadowed.
                    mel_target_cpu = mel_target[0, :length].detach().cpu().transpose(0, 1)
                    mel_aug_cpu = mel_aug[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_cpu = mel_postnet_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_noisy_cpu = mel_postnet_output_noisy[
                        0, :length].detach().cpu().transpose(0, 1)

                    wav_mel = utils.vocoder_infer(mel_torch, vocoder, os.path.join(
                        synth_path, 'step_{}_{}_{}.wav'.format(current_step, "c", hp.vocoder)))
                    wav_mel_postnet = utils.vocoder_infer(mel_postnet_torch, vocoder, os.path.join(
                        synth_path, 'step_{}_{}_postnet_{}.wav'.format(current_step, "c", hp.vocoder)))
                    wav_ground_truth = utils.vocoder_infer(mel_target_torch, vocoder, os.path.join(
                        synth_path, 'step_{}_{}_ground-truth_{}.wav'.format(current_step, "c", hp.vocoder)))
                    wav_mel_noisy = utils.vocoder_infer(mel_noisy_torch, vocoder, os.path.join(
                        synth_path, 'step_{}_{}_{}.wav'.format(current_step, "n", hp.vocoder)))
                    wav_mel_postnet_noisy = utils.vocoder_infer(mel_postnet_noisy_torch, vocoder, os.path.join(
                        synth_path, 'step_{}_{}_postnet_{}.wav'.format(current_step, "n", hp.vocoder)))
                    wav_aug = utils.vocoder_infer(mel_aug_torch, vocoder, os.path.join(
                        synth_path, 'step_{}_{}_ground-truth_{}.wav'.format(current_step, "n", hp.vocoder)))

                    # Model duration prediction: exp(log duration) - offset,
                    # rounded and clamped at 0, for item 0's source tokens.
                    log_dur_item = log_duration_output[0, :src_len[0].item()].detach().cpu()
                    dur_item = torch.clamp(
                        torch.round(torch.exp(log_dur_item) - hp.log_offset), min=0).int()
                    model_duration = utils.plot_alignment([utils.get_alignment_2D(dur_item).T])

                    # Model mel prediction
                    f0_item = f0[0, :length].detach().cpu().numpy()
                    energy_item = energy[0, :length].detach().cpu().numpy()
                    f0_pred_item = f0_output[0, :length].detach().cpu().numpy()
                    energy_pred_item = energy_output[0, :length].detach().cpu().numpy()

                    mel_predicted = utils.plot_data(
                        [(mel_postnet_cpu.numpy(), f0_pred_item, energy_pred_item),
                         (mel_target_cpu.numpy(), f0_item, energy_item)],
                        ['Synthetized Spectrogram Clean', 'Ground-Truth Spectrogram'],
                        filename=os.path.join(synth_path, 'step_{}_{}.png'.format(current_step, "c")))
                    mel_noisy_predicted = utils.plot_data(
                        [(mel_postnet_noisy_cpu.numpy(), f0_pred_item, energy_pred_item),
                         (mel_aug_cpu.numpy(), f0_item, energy_item)],
                        ['Synthetized Spectrogram Noisy', 'Aug Spectrogram'],
                        filename=os.path.join(synth_path, 'step_{}_{}.png'.format(current_step, "n")))

                    train_logger.add_image("model_duration", model_duration,
                                           current_step, dataformats='HWC')
                    train_logger.add_image("mel_predicted/Clean", mel_predicted,
                                           current_step, dataformats='HWC')
                    train_logger.add_image("mel_predicted/Noisy", mel_noisy_predicted,
                                           current_step, dataformats='HWC')

                    # Peak-normalize into [-1, 1] before logging audio to TensorBoard.
                    audio_to_log = (
                        ("Clean/wav_ground_truth", wav_ground_truth),
                        ("Clean/wav_mel", wav_mel),
                        ("Clean/wav_mel_postnet", wav_mel_postnet),
                        ("Noisy/wav_aug", wav_aug),
                        ("Noisy/wav_mel_noisy", wav_mel_noisy),
                        ("Noisy/wav_mel_postnet_noisy", wav_mel_postnet_noisy),
                    )
                    for tag, wav in audio_to_log:
                        train_logger.add_audio(tag, _peak_normalize(wav), current_step,
                                               sample_rate=hp.sampling_rate)

                if current_step == 1 or current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        (val_d, val_f, val_e, val_cl_a, val_cl_a_dat,
                         val_m, val_m_p, val_m_n, val_m_p_n) = evaluate(model, current_step)
                        val_total = val_d + val_f + val_e + val_m + val_m_p + val_m_n + val_m_p_n \
                            + hp.dat_weight * (val_cl_a + val_cl_a_dat)

                        _log_styler_losses(val_logger, current_step, val_total, val_m, val_m_p,
                                           val_m_n, val_m_p_n, val_d, val_f, val_e,
                                           val_cl_a, val_cl_a_dat)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the timing history into its mean to bound its size.
                    Time = np.array([np.mean(Time)])
def train(rank, args, hp, hp_str):
    """Train the generator against a spectrogram and a multi-scale discriminator.

    Args:
        rank: CUDA device index used by this worker (``cuda:{rank}``); rank 0
            additionally logs, checkpoints and validates.
        args: Command-line arguments; ``training_epochs`` and
            ``stdout_interval`` are read here, and ``args`` is forwarded to
            ``get_dataset_filelist``.
        hp: Hyper-parameter namespace (``hp.train``, ``hp.model``,
            ``hp.audio``, ``hp.data`` and ``hp.logs`` sections are read).
        hp_str: Serialized hyper-parameters stored in every ``do_`` checkpoint.
    """
    # NOTE(review): the DistributedDataParallel wrapping below presumably needs
    # an initialised process group; the init_process_group(...) call that used
    # to live here is disabled -- confirm it happens elsewhere for num_gpus > 1.
    torch.cuda.manual_seed(hp.train.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = Generator(hp.model.in_channels, hp.model.out_channels).to(device)
    specd = SpecDiscriminator().to(device)
    msd = MultiScaleDiscriminator().to(device)
    stft_loss = MultiResolutionSTFTLoss()

    if rank == 0:
        print(generator)
        os.makedirs(hp.logs.chkpt_dir, exist_ok=True)
        print("checkpoints directory : ", hp.logs.chkpt_dir)

    # Bind both names up front: a missing checkpoint directory (e.g. on a
    # non-zero rank that did not create it) means "start from scratch"
    # instead of a NameError below.
    cp_g = cp_do = None
    if os.path.isdir(hp.logs.chkpt_dir):
        cp_g = scan_checkpoint(hp.logs.chkpt_dir, 'g_')
        cp_do = scan_checkpoint(hp.logs.chkpt_dir, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        specd.load_state_dict(state_dict_do['specd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if hp.train.num_gpus > 1:
        generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
        specd = DistributedDataParallel(specd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    optim_g = torch.optim.AdamW(
        generator.parameters(), hp.train.adamG.lr,
        betas=[hp.train.adamG.beta1, hp.train.adamG.beta2])
    optim_d = torch.optim.AdamW(
        itertools.chain(msd.parameters(), specd.parameters()),
        hp.train.adamD.lr,
        betas=[hp.train.adamD.beta1, hp.train.adamD.beta2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    # Learning-rate decay (ExponentialLR schedulers) is currently disabled.

    training_filelist, validation_filelist = get_dataset_filelist(args)

    trainset = MelDataset(training_filelist, hp.data.input_wavs, hp.data.output_wavs,
                          hp.audio.segment_length, hp.audio.filter_length,
                          hp.audio.n_mel_channels, hp.audio.hop_length,
                          hp.audio.win_length, hp.audio.sampling_rate,
                          hp.audio.mel_fmin, hp.audio.mel_fmax, n_cache_reuse=0,
                          shuffle=False if hp.train.num_gpus > 1 else True,
                          fmax_loss=None, device=device)

    train_sampler = DistributedSampler(trainset) if hp.train.num_gpus > 1 else None

    train_loader = DataLoader(trainset, num_workers=hp.train.num_workers, shuffle=False,
                              sampler=train_sampler, batch_size=hp.train.batch_size,
                              pin_memory=True, drop_last=True)

    if rank == 0:
        validset = MelDataset(validation_filelist, hp.data.input_wavs, hp.data.output_wavs,
                              hp.audio.segment_length, hp.audio.filter_length,
                              hp.audio.n_mel_channels, hp.audio.hop_length,
                              hp.audio.win_length, hp.audio.sampling_rate,
                              hp.audio.mel_fmin, hp.audio.mel_fmax, split=False,
                              shuffle=False, n_cache_reuse=0, fmax_loss=None,
                              device=device)
        validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
                                       sampler=None, batch_size=1, pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(hp.logs.chkpt_dir, 'logs'))

    generator.train()
    specd.train()
    msd.train()
    with_postnet = False
    for epoch in range(max(0, last_epoch), args.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if hp.train.num_gpus > 1:
            train_sampler.set_epoch(epoch)

        for batch in train_loader:
            if rank == 0:
                start_b = time.time()

            # Once enabled, the postnet stays on for the rest of training.
            if steps > hp.train.postnet_start_steps:
                with_postnet = True

            x, y, _, _, y_mel_loss = batch
            x = x.to(device, non_blocking=True).unsqueeze(1)
            y = y.to(device, non_blocking=True).unsqueeze(1)
            y_mel_loss = y_mel_loss.to(device, non_blocking=True)

            # The generator returns the pre-postnet waveform and the postnet
            # waveform; the latter is None while the postnet is disabled.
            before_y_g_hat, y_g_hat = generator(x, with_postnet)

            if y_g_hat is not None:
                y_g_hat_mel = mel_spectrogram(
                    y_g_hat.squeeze(1), hp.audio.filter_length,
                    hp.audio.n_mel_channels, hp.audio.sampling_rate,
                    hp.audio.hop_length, hp.audio.win_length,
                    hp.audio.mel_fmin, None)

                # Discriminators only train on postnet output, so this must
                # stay under the y_g_hat guard (y_g_hat_mel exists only here).
                if steps > hp.train.discriminator_train_start_steps:
                    for _ in range(hp.train.rep_discriminator):
                        optim_d.zero_grad()

                        # SpecD
                        y_df_hat_r, y_df_hat_g, _, _ = specd(
                            y_mel_loss, y_g_hat_mel.detach())
                        loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
                            y_df_hat_r, y_df_hat_g)

                        # MSD
                        y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
                        loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
                            y_ds_hat_r, y_ds_hat_g)

                        loss_disc_all = loss_disc_s + loss_disc_f
                        loss_disc_all.backward()
                        optim_d.step()

            before_y_g_hat_mel = mel_spectrogram(
                before_y_g_hat.squeeze(1), hp.audio.filter_length,
                hp.audio.n_mel_channels, hp.audio.sampling_rate,
                hp.audio.hop_length, hp.audio.win_length,
                hp.audio.mel_fmin, None)

            # Generator
            optim_g.zero_grad()

            # Multi-resolution STFT loss on the pre-postnet output, trimmed to
            # the target length.
            sc_loss, mag_loss = stft_loss(
                before_y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
            before_loss_mel = sc_loss + mag_loss

            # L1 Sample Loss
            before_loss_sample = F.l1_loss(y, before_y_g_hat)

            loss_gen_all = before_loss_mel + before_loss_sample

            if y_g_hat is not None:
                # Same reconstruction losses on the postnet output
                sc_loss_, mag_loss_ = stft_loss(
                    y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
                loss_mel = sc_loss_ + mag_loss_
                loss_sample = F.l1_loss(y, y_g_hat)
                loss_gen_all += loss_mel + loss_sample

                if steps > hp.train.discriminator_train_start_steps:
                    y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = specd(
                        y_mel_loss, y_g_hat_mel)
                    y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
                    loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
                    loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
                    loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
                    loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
                    loss_gen_all += hp.model.lambda_adv * (
                        loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f)

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                log_stdout = steps % args.stdout_interval == 0
                log_summary = steps % hp.logs.summary_interval == 0

                # Computed for both destinations so the summary branch never
                # reads an undefined or stale mel_error.
                if log_stdout or log_summary:
                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel_loss, before_y_g_hat_mel).item()
                        sample_error = F.l1_loss(y, before_y_g_hat).item()

                # STDOUT logging
                if log_stdout:
                    print(
                        'Steps : {:d}, Gen Loss Total : {:4.3f}, Sample Error: {:4.3f}, '
                        'Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.format(
                            steps, loss_gen_all.item(), sample_error, mel_error,
                            time.time() - start_b))

                # checkpointing
                if steps % hp.logs.save_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_{:08d}".format(hp.logs.chkpt_dir, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'generator': (generator.module if hp.train.num_gpus > 1
                                          else generator).state_dict()
                        })
                    checkpoint_path = "{}/do_{:08d}".format(hp.logs.chkpt_dir, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'specd': (specd.module if hp.train.num_gpus > 1
                                      else specd).state_dict(),
                            'msd': (msd.module if hp.train.num_gpus > 1
                                    else msd).state_dict(),
                            'optim_g': optim_g.state_dict(),
                            'optim_d': optim_d.state_dict(),
                            'steps': steps,
                            'epoch': epoch,
                            'hp_str': hp_str
                        })

                # Tensorboard summary logging
                if log_summary:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
                    sw.add_scalar("training/mel_spec_error", mel_error, steps)

                # Validation
                if steps % hp.logs.validation_interval == 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    num_val_batches = 0
                    with torch.no_grad():
                        for j, batch in enumerate(validation_loader):
                            x, y, _, y_mel, y_mel_loss = batch
                            x = x.unsqueeze(1)
                            y = y.unsqueeze(1).to(device)
                            before_y_g_hat, y_g_hat = generator(x.to(device))
                            y_mel_loss = y_mel_loss.to(device, non_blocking=True)
                            y_g_hat_mel = mel_spectrogram(
                                before_y_g_hat.squeeze(1), hp.audio.filter_length,
                                hp.audio.n_mel_channels, hp.audio.sampling_rate,
                                hp.audio.hop_length, hp.audio.win_length,
                                hp.audio.mel_fmin, None)
                            val_err_tot += F.l1_loss(y_mel_loss, y_g_hat_mel).item()
                            val_err_tot += F.l1_loss(y, before_y_g_hat).item()
                            if y_g_hat is not None:
                                val_err_tot += F.l1_loss(y, y_g_hat).item()
                            num_val_batches += 1

                            # Log audio/spectrograms for the first five utterances.
                            if j <= 4:
                                if steps == 0:
                                    sw.add_audio('gt_noise/y_{}'.format(j), x[0], steps,
                                                 hp.audio.sampling_rate)
                                    sw.add_audio('gt_clean/y_{}'.format(j), y[0], steps,
                                                 hp.audio.sampling_rate)
                                    sw.add_figure('gt/y_spec_clean_{}'.format(j),
                                                  plot_spectrogram(y_mel[0]), steps)

                                sw.add_audio('generated/y_hat_{}'.format(j),
                                             before_y_g_hat[0], steps,
                                             hp.audio.sampling_rate)
                                if y_g_hat is not None:
                                    sw.add_audio('generated/y_hat_after_{}'.format(j),
                                                 y_g_hat[0], steps,
                                                 hp.audio.sampling_rate)
                                # y_g_hat_mel was computed above with exactly
                                # these arguments; reuse it instead of recomputing.
                                sw.add_figure(
                                    'generated/y_hat_spec_{}'.format(j),
                                    plot_spectrogram(
                                        y_g_hat_mel.squeeze(0).cpu().numpy()),
                                    steps)

                        # max(..., 1) avoids NameError/ZeroDivisionError on an
                        # empty validation set.
                        val_err = val_err_tot / max(num_val_batches, 1)
                        sw.add_scalar("validation/mel_spec_error", val_err, steps)

                    generator.train()

            steps += 1

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
def main():
    """Prepare data and logging, visualise a few inputs, then train and test.

    Configuration comes from ``get_parameters()``. When ``config.tensorboard``
    is set, a SummaryWriter is opened at
    ``{config.logging_dir}/{config.experiment_description}``; in debug mode that
    directory is removed first. After ``Solver.test()`` the true classes and
    predicted scores are saved as ``true_class.npy`` / ``pred_scores.npy`` in
    ``config.result_dir``.
    """
    # Reproducibility
    np.random.seed(12345)
    torch.manual_seed(12345)

    # Preparation
    config = get_parameters()

    # Logging configuration
    writer = None
    if config.tensorboard:
        path_tensorboard = f'{config.logging_dir}/{config.experiment_description}'
        if config.debug_mode:  # Clear tensorboard when debugging
            if os.path.exists(path_tensorboard):
                shutil.rmtree(path_tensorboard)
        writer = SummaryWriter(path_tensorboard)

    data_loader_train, data_loader_valid, data_loader_test = get_data(config)

    transforms = get_time_frequency_transform(config) if config.use_time_freq else None

    # =====================================================================
    # Visualize some data
    tmp_audio = None
    tmp_spec = None
    tmp_data, _, _ = data_loader_train.dataset[0]  # audio is [channels, timesteps]

    # Is the data audio or image? 2-D samples are treated as audio.
    if len(tmp_data.shape) == 2:
        tmp_audio = tmp_data
    else:
        tmp_spec = tmp_data

    # Only raw audio can be turned into a spectrogram; without this guard an
    # image dataset with use_time_freq would call transforms(None).
    if config.use_time_freq and tmp_audio is not None:
        tmp_spec = transforms(tmp_audio)  # spec is [channels, freq_bins, frames]

    if tmp_spec is not None:
        utils.show_spectrogram(tmp_spec, config)

    if writer is not None:
        if tmp_audio is not None:
            # Store at most 5 secs of audio
            ind = min(tmp_audio.shape[-1], 5 * config.original_fs)
            writer.add_audio('input_audio', tmp_audio[:, 0:ind], None, config.original_fs)

            tmp_audios = []
            fnames = []
            for i in range(4):
                aud, _, fn = data_loader_train.dataset.dataset[i]
                fnames.append(fn)
                tmp_audios.append(aud)

            writer.add_figure('input_waveform',
                              utils.show_waveforms_batch(tmp_audios, fnames, config),
                              None)

        # Analyze some spectrograms
        if tmp_spec is not None:
            img_tform = tforms_vision.Compose([
                tforms_vision.ToPILImage(),
                tforms_vision.ToTensor(),
            ])

            writer.add_image('input_spec', img_tform(tmp_spec), None)  # Raw tensor
            writer.add_figure('input_spec_single',
                              utils.show_spectrogram(tmp_spec, config), None)  # Librosa

            if config.use_time_freq:
                tmp_specs = []
                fnames = []
                for i in range(4):
                    aud, _, fn = data_loader_train.dataset.dataset[i]
                    tmp_specs.append(transforms(aud))
                    fnames.append(fn)

                writer.add_figure('input_spec_batch',
                                  utils.show_spectrogram_batch(tmp_specs, fnames, config),
                                  None)
                writer.add_figure('input_spec_histogram',
                                  utils.get_spectrogram_histogram(tmp_specs), None)

                # Release the sample tensors before training starts.
                del tmp_specs, fnames, aud, fn, i

    # Class Histograms
    if not config.dataset_skip_class_hist:
        fig_classes = utils.get_class_histograms(
            data_loader_train.dataset,
            data_loader_valid.dataset,
            data_loader_test.dataset,
            one_hot_encoder=utils.OneHot if config.dataset == 'MNIST' else None,
            data_limit=200 if config.debug_mode else None)
        if writer is not None:
            writer.add_figure('class_histogram', fig_classes, None)

    # =====================================================================
    # Train and Test
    solver = Solver(data_loader_train, data_loader_valid, data_loader_test,
                    config, writer, transforms)
    solver.train()
    scores, true_class, pred_scores = solver.test()

    # =====================================================================
    # Save results. Passing the path lets np.save open *and close* the file
    # (the previous np.save(open(...)) leaked both handles). The names already
    # end in ".npy", so np.save does not append another suffix.
    np.save(os.path.join(config.result_dir, 'true_class.npy'), true_class)
    np.save(os.path.join(config.result_dir, 'pred_scores.npy'), pred_scores)
    utils.compare_predictions(true_class, pred_scores, config.result_dir)

    if writer is not None:
        writer.close()
class DurationExtractor(nn.Module):
    """The teacher model for duration extraction.

    Combines a convolutional text encoder, an audio encoder/decoder pair and a
    scaled dot-product attention between audio queries and text keys/values.
    Training minimises a masked L1 spectrogram loss plus a guided-attention
    loss; the optimizer and Noam scheduler live on the module itself.
    """

    def __init__(self, adam_lr=0.002, warmup_epochs=30, init_scale=0.25,
                 guided_att_sigma=0.3, device='cuda'):
        super(DurationExtractor, self).__init__()

        self.txt_encoder = ConvTextEncoder()
        self.audio_encoder = ConvAudioEncoder()
        self.audio_decoder = ConvAudioDecoder()
        self.attention = ScaledDotAttention()
        self.collate = Collate(device=device)

        # optim
        self.optimizer = torch.optim.Adam(self.parameters(), lr=adam_lr)
        self.scheduler = NoamScheduler(self.optimizer, warmup_epochs, init_scale)

        # losses
        self.loss_l1 = l1_masked
        self.loss_att = GuidedAttentionLoss(guided_att_sigma)

        # device
        self.device = device
        self.to(self.device)
        print(f'Model sent to {self.device}')

        # helper vars
        self.checkpoint = None
        self.epoch = 0
        self.step = 0

        # Stored in checkpoints so a load on different code can be flagged.
        repo = git.Repo(search_parent_directories=True)
        self.git_commit = repo.head.object.hexsha

    def to_device(self, device):
        """Move the network to ``device``, remember it, and return ``self``."""
        print(f'Sending network to {device}')
        self.device = device
        self.to(device)
        return self

    def save(self):
        """Write a checkpoint into the logger's directory, replacing the previous one.

        Requires ``fit`` to have created ``self.logger``.
        """
        if self.checkpoint is not None:
            os.remove(self.checkpoint)
        self.checkpoint = os.path.join(
            self.logger.log_dir,
            f'{time.strftime("%Y-%m-%d")}_checkpoint_step{self.step}.pth')
        torch.save(
            {
                'epoch': self.epoch,
                'step': self.step,
                'state_dict': self.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
                'git_commit': self.git_commit
            },
            self.checkpoint)

    def load(self, checkpoint):
        """Restore model, optimizer and scheduler state from ``checkpoint``.

        Tensors are mapped onto ``self.device`` so a checkpoint written on a GPU
        can be loaded on a CPU-only machine (and vice versa).
        """
        checkpoint = torch.load(checkpoint, map_location=self.device)
        self.epoch = checkpoint['epoch']
        self.step = checkpoint['step']
        self.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])

        commit = checkpoint['git_commit']
        if commit != self.git_commit:
            print(f'Warning: the loaded checkpoint was trained on commit {commit}, but you are on {self.git_commit}')
        self.checkpoint = None  # prevent overriding old checkpoint
        return self

    def forward(self, phonemes, spectrograms, len_phonemes, training=False):
        """Teacher-forced decoding of ``spectrograms`` conditioned on ``phonemes``.

        :param phonemes: padded phoneme input for ``txt_encoder``
            (NOTE(review): exact layout not visible here -- confirm)
        :param spectrograms: padded spectrograms with time on dim 1, i.e.
            (batch, time, freq) -- the shift below and ``generate`` index dim 1
        :param len_phonemes: list of phoneme lengths
        :param training: unused
        :return: decoded_spectrograms, attention_weights
        """
        # Prepend one zero frame on the time axis and drop the last frame, so
        # the decoder predicts frame t from frames before t.
        spectrs = ZeroPad2d((0, 0, 1, 0))(spectrograms)[:, :-1, :]  # move this to encoder?
        keys, values = self.txt_encoder(phonemes)
        queries = self.audio_encoder(spectrs)

        # Mask attention to padded phoneme positions.
        att_mask = mask(shape=(len(keys), queries.shape[1], keys.shape[1]),
                        lengths=len_phonemes, dim=-1).to(self.device)

        if hp.positional_encoding:
            keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
            queries += positional_encoding(queries.shape[-1], queries.shape[1], w=1).to(self.device)

        attention, weights = self.attention(queries, keys, values, mask=att_mask)
        decoded = self.audio_decoder(attention + queries)
        return decoded, weights

    def generating(self, mode):
        """Put the module into mode for sequential generation"""
        for module in self.children():
            if hasattr(module, 'generating'):
                module.generating(mode)

    def generate(self, phonemes, len_phonemes, steps=False, window=3, spectrograms=None):
        """Sequentially generate spectrogram from phonemes

        If spectrograms are provided, they are used on input instead of self-generated frames (teacher forcing)
        If steps are provided with spectrograms, only 'steps' frames will be generated in supervised fashion
        Uses layer-level caching for faster inference.

        :param phonemes: Padded phoneme indices
        :param len_phonemes: Length of each sentence in `phonemes` (list of lengths)
        :param steps: How many steps to generate
        :param window: Window size for attention masking
        :param spectrograms: Padded spectrograms
        :return: Generated spectrograms
        """
        self.generating(True)
        self.train(False)

        assert steps or (spectrograms is not None)
        steps = steps if steps else spectrograms.shape[1]

        with torch.no_grad():
            phonemes = torch.as_tensor(phonemes)
            keys, values = self.txt_encoder(phonemes)

            if hp.positional_encoding:
                keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
                pe = positional_encoding(hp.channels, steps, w=1).to(self.device)

            if spectrograms is None:
                dec = torch.zeros(len(phonemes), 1, hp.out_channels, device=self.device)
            else:
                # Same one-frame right shift as in forward().
                inputs = ZeroPad2d((0, 0, 1, 0))(spectrograms)[:, :-1, :]

            weights, decoded = None, None

            if window is not None:
                # Windowed attention centred on the previous argmax position,
                # starting at position 0.
                shape = (len(phonemes), 1, phonemes.shape[-1])
                idx = torch.zeros(len(phonemes), 1, phonemes.shape[-1]).to(phonemes.device)
                att_mask = idx_mask(shape, idx, window)
            else:
                att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                                lengths=len_phonemes, dim=-1).to(self.device)

            for i in range(steps):
                if spectrograms is None:
                    queries = self.audio_encoder(dec)
                else:
                    queries = self.audio_encoder(inputs[:, i:i + 1, :])

                if hp.positional_encoding:
                    queries += pe[i]

                att, w = self.attention(queries, keys, values, att_mask)
                dec = self.audio_decoder(att + queries)
                weights = w if weights is None else torch.cat((weights, w), dim=1)
                decoded = dec if decoded is None else torch.cat((decoded, dec), dim=1)
                if window is not None:
                    idx = torch.argmax(w, dim=-1).unsqueeze(2).float()
                    att_mask = idx_mask(shape, idx, window)

        self.generating(False)
        return decoded, weights

    def generate_naive(self, phonemes, len_phonemes, steps=1, window=(0, 1)):
        """Naive generation without layer-level caching for testing purposes"""
        self.train(False)

        with torch.no_grad():
            phonemes = torch.as_tensor(phonemes)
            keys, values = self.txt_encoder(phonemes)

            if hp.positional_encoding:
                keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
                pe = positional_encoding(hp.channels, steps, w=1).to(self.device)

            # Start from a single all-zero frame; it is dropped on return.
            dec = torch.zeros(len(phonemes), 1, hp.out_channels, device=self.device)

            weights = None

            att_mask = mask(shape=(len(phonemes), 1, keys.shape[1]),
                            lengths=len_phonemes, dim=-1).to(self.device)

            for i in range(steps):
                # The whole prefix is re-encoded every step (no caching).
                queries = self.audio_encoder(dec)
                if hp.positional_encoding:
                    queries += pe[i]

                att, w = self.attention(queries, keys, values, att_mask)
                d = self.audio_decoder(att + queries)
                d = d[:, -1:]
                w = w[:, -1:]
                weights = w if weights is None else torch.cat((weights, w), dim=1)
                dec = torch.cat((dec, d), dim=1)

                if window is not None:
                    att_mask = median_mask(weights, window=window)

        return dec[:, 1:, :], weights

    def fit(self, batch_size, logdir, epochs=1, grad_clip=1, checkpoint_every=10):
        """Train for ``epochs`` epochs, logging to TensorBoard under ``logdir``.

        Continues numbering from ``self.epoch`` when a checkpoint was loaded and
        saves a checkpoint every ``checkpoint_every`` epochs.
        """
        self.grad_clip = grad_clip
        self.logger = SummaryWriter(logdir)

        train_loader = self.train_dataloader(batch_size)
        valid_loader = self.val_dataloader(batch_size)

        # continue training from self.epoch if checkpoint loaded
        for e in range(self.epoch + 1, self.epoch + 1 + epochs):
            self.epoch = e
            train_losses = self._train_epoch(train_loader)
            valid_losses = self._validate(valid_loader)

            self.scheduler.step()
            self.logger.add_scalar('train/learning_rate',
                                   self.optimizer.param_groups[0]['lr'], self.epoch)
            if not e % checkpoint_every:
                self.save()

            print(f'Epoch {e} | Train - l1: {train_losses[0]}, guided_att: {train_losses[1]}| '
                  f'Valid - l1: {valid_losses[0]}, guided_att: {valid_losses[1]}|')

    def _train_epoch(self, dataloader):
        """Run one training epoch; return the mean (l1, guided_att) loss per batch."""
        self.train()

        t_l1, t_att = 0, 0
        n_batches = 0
        for n_batches, batch in enumerate(Bar(dataloader), start=1):
            self.optimizer.zero_grad()
            spectrs, slen, phonemes, plen, text = batch

            # Corrupt the decoder input (noise, partial self-feeding, frame
            # dropout); the clean spectrogram stays the target.
            s = add_random_noise(spectrs, hp.noise)
            s = degrade_some(self, s, phonemes, plen, hp.feed_ratio, repeat=hp.feed_repeat)
            s = frame_dropout(s, hp.replace_ratio)

            out, att_weights = self.forward(phonemes, s, plen)

            l1 = self.loss_l1(out, spectrs, slen)
            l_att = self.loss_att(att_weights, slen, plen)

            loss = l1 + l_att
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
            self.optimizer.step()
            self.step += 1

            t_l1 += l1.item()
            t_att += l_att.item()

            self.logger.add_scalar('batch/total', loss.item(), self.step)

        # report average cost per batch (divide by the batch count, not the
        # last index, which was off by one and 0 for a single batch)
        n_batches = max(n_batches, 1)
        self.logger.add_scalar('train/l1', t_l1 / n_batches, self.epoch)
        self.logger.add_scalar('train/guided_att', t_att / n_batches, self.epoch)
        return t_l1 / n_batches, t_att / n_batches

    def _validate(self, dataloader):
        """Evaluate with free-running generation; return mean (l1, guided_att) per batch."""
        self.eval()

        t_l1, t_att = 0, 0
        n_batches = 0
        for n_batches, batch in enumerate(dataloader, start=1):
            spectrs, slen, phonemes, plen, text = batch
            # generate sequentially
            out, att_weights = self.generate(phonemes, plen, steps=spectrs.shape[1], window=None)

            # generate in supervised fashion - for visualisation only
            with torch.no_grad():
                out_s, att_s = self.forward(phonemes, spectrs, plen)

            l1 = self.loss_l1(out, spectrs, slen)
            l_att = self.loss_att(att_weights, slen, plen)

            t_l1 += l1.item()
            t_att += l_att.item()

        # Visualise the last utterance of the last batch (skipped when the
        # loader was empty, where these names would be unbound).
        if n_batches:
            fig = display_spectr_alignment(out[-1, :slen[-1]],
                                           att_weights[-1][:slen[-1], :plen[-1]],
                                           out_s[-1, :slen[-1]],
                                           att_s[-1][:slen[-1], :plen[-1]],
                                           text[-1])
            self.logger.add_figure(text[-1], fig, self.epoch)

            if not self.epoch % 10:
                spec = self.collate.norm.inverse(out[-1:])  # TODO: this fails if we do not standardize!
                sound, length = self.collate.stft.spec2wav(spec.transpose(1, 2), slen[-1:])
                sound = sound[0, :length[0]]
                self.logger.add_audio(text[-1], sound.detach().cpu().numpy(), self.epoch,
                                      sample_rate=22050)  # TODO: parameterize

        # report average cost per batch
        n_batches = max(n_batches, 1)
        self.logger.add_scalar('valid/l1', t_l1 / n_batches, self.epoch)
        self.logger.add_scalar('valid/guided_att', t_att / n_batches, self.epoch)
        return t_l1 / n_batches, t_att / n_batches

    def train_dataloader(self, batch_size):
        """Shuffled loader over items [0, HPText.num_train)."""
        return DataLoader(AudioDataset(HPText.dataset, start_idx=0, end_idx=HPText.num_train,
                                       durations=False),
                          batch_size=batch_size,
                          collate_fn=self.collate,
                          shuffle=True)

    def val_dataloader(self, batch_size):
        """Sequential loader over items [HPText.num_train, HPText.num_valid)."""
        dataset = AudioDataset(HPText.dataset, start_idx=HPText.num_train,
                               end_idx=HPText.num_valid, durations=False)
        return DataLoader(dataset,
                          batch_size=batch_size,
                          collate_fn=self.collate,
                          shuffle=False,
                          sampler=SequentialSampler(dataset))
def main(args):
    """Train (unless ``args.test``) and then test a Wave-U-Net student model.

    When ``args.teacher_model`` is given, the student is trained with knowledge
    distillation via ``utils.KD_compute_loss`` against that (frozen) teacher;
    otherwise it is trained directly with an MSE loss. Each epoch is validated,
    checkpointed, and the checkpoint with the best validation PESQ
    (``val_metrics[0]``) is tracked in ``state["best_checkpoint"]``. Finally that
    checkpoint is evaluated on the test fold and results are saved under
    ``result_dir``.
    """
    os.environ['KMP_WARNINGS'] = '0'
    torch.cuda.manual_seed_all(1)
    np.random.seed(0)
    print(args.model_name)
    print(args.alpha)

    # filter array
    num_features = [
        args.features * i
        for i in range(1, args.levels + 2 + args.levels_without_sample)
    ]

    # Output size in samples
    target_outputs = int(args.output_size * args.sr)

    # Student model (the one that is trained) and its optimizer
    student_KD = Waveunet(args.channels,
                          num_features,
                          args.channels,
                          levels=args.levels,
                          encoder_kernel_size=args.encoder_kernel_size,
                          decoder_kernel_size=args.decoder_kernel_size,
                          target_output_size=target_outputs,
                          depth=args.depth,
                          strides=args.strides,
                          conv_type=args.conv_type,
                          res=args.res)
    KD_optimizer = Adam(params=student_KD.parameters(), lr=args.lr)
    print(25 * '=' + 'model setting' + 25 * '=')
    print('student_KD: ', student_KD.shapes)
    if args.cuda:
        student_KD = utils.DataParallel(student_KD)
        print("move student_KD to gpu\n")
        student_KD.cuda()

    # np.Inf was removed in NumPy 2.0; np.inf works on every NumPy version.
    state = {"step": 0, "worse_epochs": 0, "epochs": 0, "best_pesq": -np.inf}
    if args.load_model is not None:
        print("Continuing full model from checkpoint " + str(args.load_model))
        state = utils.load_model(student_KD, KD_optimizer, args.load_model, args.cuda)

    dataset = get_folds(args.dataset_dir, args.outside_test)
    log_dir, checkpoint_dir, result_dir = utils.mkdir_and_get_path(args)

    if args.test is False:
        writer = SummaryWriter(log_dir)

        # Record the hyper-parameters used for this run
        print(25 * '=' + 'printing hypeparameters info' + 25 * '=')
        with open(os.path.join(log_dir, 'config.json'), 'w') as f:
            json.dump(args.__dict__, f, indent=5)
        print('saving commandline_args')
        student_size = sum(p.numel() for p in student_KD.parameters())
        print('student_parameter count: ', str(student_size))

        if args.teacher_model is not None:
            teacher_num_features = [
                24 * i
                for i in range(1, args.levels + 2 + args.levels_without_sample)
            ]
            teacher_model = Waveunet(
                args.channels,
                teacher_num_features,
                args.channels,
                levels=args.levels,
                encoder_kernel_size=args.encoder_kernel_size,
                decoder_kernel_size=args.decoder_kernel_size,
                target_output_size=target_outputs,
                depth=args.depth,
                strides=args.strides,
                conv_type=args.conv_type,
                res=args.res)
            if args.cuda:
                teacher_model = utils.DataParallel(teacher_model)
                teacher_model.cuda()
            student_size = sum(p.numel() for p in student_KD.parameters())
            teacher_size = sum(p.numel() for p in teacher_model.parameters())
            print('student_parameter count: ', str(student_size))
            print('teacher_model_parameter count: ', str(teacher_size))
            print(f'compression raito :{100*(student_size/teacher_size)}%')

        if args.teacher_model is not None:
            print("load teacher model" + str(args.teacher_model))
            _ = utils.load_model(teacher_model, None, args.teacher_model, args.cuda)
            teacher_model.eval()

        # If not data augmentation, at least crop targets to fit model output shape
        crop_func = partial(crop, shapes=student_KD.shapes)

        ### DATASET
        train_data = SeparationDataset(dataset, "train", args.sr, args.channels,
                                       student_KD.shapes, False, args.hdf_dir,
                                       audio_transform=crop_func)
        val_data = SeparationDataset(dataset, "test", args.sr, args.channels,
                                     student_KD.shapes, False, args.hdf_dir,
                                     audio_transform=crop_func)
        dataloader = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            worker_init_fn=utils.worker_init_fn,
            pin_memory=True)

        # Set up the loss function
        if args.loss == "L1":
            criterion = nn.L1Loss()
        elif args.loss == "L2":
            criterion = nn.MSELoss()
        else:
            raise NotImplementedError("Couldn't find this loss!")
        My_criterion = customLoss()

        ### TRAINING START
        print('TRAINING START')
        batch_num = (len(train_data) // args.batch_size)
        while state["epochs"] < 100:
            # Kept empty and pickled each epoch (nothing appends to it here).
            memory_alpha = []
            print("epoch:" + str(state["epochs"]))
            student_KD.train()
            # monitor_value
            avg_origin_loss = 0
            with tqdm(total=len(dataloader)) as pbar:
                for example_num, (x, targets) in enumerate(dataloader):
                    if args.cuda:
                        x = x.cuda()
                        targets = targets.cuda()

                    # Set LR for this iteration (cyclic schedule)
                    utils.set_cyclic_lr(KD_optimizer, example_num,
                                        len(train_data) // args.batch_size,
                                        args.cycles, args.min_lr, args.lr)

                    if args.teacher_model is not None:
                        # Plain student loss, for monitoring only (no gradients)
                        _, avg_student_KD_loss = utils.compute_loss(
                            student_KD, x, targets, criterion, compute_grad=False)

                        KD_optimizer.zero_grad()
                        KD_outputs, KD_hard_loss, KD_loss, KD_soft_loss = utils.KD_compute_loss(
                            student_KD, teacher_model, x, targets, My_criterion,
                            alpha=args.alpha, compute_grad=True,
                            KD_method=args.KD_method)
                        KD_optimizer.step()

                        avg_origin_loss += avg_student_KD_loss / batch_num

                        # add to tensorboard
                        writer.add_scalar("KD_loss", KD_loss, state["step"])
                        writer.add_scalar("KD_hard_loss", KD_hard_loss, state["step"])
                        writer.add_scalar("KD_soft_loss", KD_soft_loss, state["step"])
                    else:
                        # no KD training
                        KD_optimizer.zero_grad()
                        KD_outputs, KD_hard_loss = utils.compute_loss(
                            student_KD, x, targets, nn.MSELoss(), compute_grad=True)
                        KD_optimizer.step()
                        avg_origin_loss += KD_hard_loss / batch_num
                        writer.add_scalar("student_KD_loss", KD_hard_loss, state["step"])

                    ### save wav ####
                    if example_num % args.example_freq == 0:
                        input_centre = torch.mean(
                            x[0, :, student_KD.shapes["output_start_frame"]:
                              student_KD.shapes["output_end_frame"]],
                            0)  # Stereo not supported for logs yet
                        writer.add_audio("input:", input_centre, state["step"], sample_rate=args.sr)
                        writer.add_audio("pred:", torch.mean(KD_outputs[0], 0), state["step"],
                                         sample_rate=args.sr)
                        writer.add_audio("target", torch.mean(targets[0], 0), state["step"],
                                         sample_rate=args.sr)

                    state["step"] += 1
                    pbar.update(1)

            # VALIDATE
            val_loss, val_metrics = validate(args, student_KD, criterion, val_data)
            print("ori VALIDATION FINISHED: LOSS: " + str(val_loss))
            writer.add_scalar("avg_origin_loss", avg_origin_loss, state["epochs"])
            writer.add_scalar("val_enhance_pesq", val_metrics[0], state["epochs"])
            writer.add_scalar("val_improve_pesq", val_metrics[1], state["epochs"])
            writer.add_scalar("val_enhance_stoi", val_metrics[2], state["epochs"])
            writer.add_scalar("val_improve_stoi", val_metrics[3], state["epochs"])
            writer.add_scalar("val_enhance_SISDR", val_metrics[4], state["epochs"])
            writer.add_scalar("val_improve_SISDR", val_metrics[5], state["epochs"])
            writer.add_scalar("val_loss", val_loss, state["epochs"])

            # Set up training state dict that will also be saved into checkpoints
            checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_" + str(state["epochs"]))
            if val_metrics[0] < state["best_pesq"]:
                state["worse_epochs"] += 1
            else:
                print("MODEL IMPROVED ON VALIDATION SET!")
                state["worse_epochs"] = 0
                state["best_pesq"] = val_metrics[0]
                state["best_checkpoint"] = checkpoint_path

            # CHECKPOINT
            print("Saving model...")
            utils.save_model(student_KD, KD_optimizer, state, checkpoint_path)
            print('dump alpha_memory')
            with open(os.path.join(log_dir, 'alpha_' + str(state["epochs"])), "wb") as fp:
                pickle.dump(memory_alpha, fp)

            state["epochs"] += 1
        writer.close()
        info = args.model_name
        path = os.path.join(result_dir, info)
    else:
        # Name the result folder after the run directory and checkpoint file
        PATH = args.load_model.split("/")
        info = PATH[-3] + "_" + PATH[-1]
        if args.outside_test:
            info += "_outside_test"
        print(info)
        path = os.path.join(result_dir, info)

    #### TESTING ####
    # Test loss
    print("TESTING")
    # eval metrics
    _ = utils.load_model(student_KD, KD_optimizer, state["best_checkpoint"], args.cuda)
    test_metrics = evaluate(args, dataset["test"], student_KD)
    test_pesq = test_metrics['pesq']
    test_stoi = test_metrics['stoi']
    test_SISDR = test_metrics['SISDR']
    test_noise = test_metrics['noise']

    os.makedirs(path, exist_ok=True)
    utils.save_result(test_pesq, path, "pesq")
    utils.save_result(test_stoi, path, "stoi")
    utils.save_result(test_SISDR, path, "SISDR")
    utils.save_result(test_noise, path, "noise")
class TacotronTrainer:
    """Train a Tacotron model on a train/val split with TensorBoard logging.

    Each epoch runs every dataloader stage (train, then val) and logs the mean
    loss per stage. Every ``log_interval`` training steps the trainer logs the
    running training loss, sample spectrogram/attention plots for a few
    validation utterances (plus audio every ``log_audio_factor`` intervals),
    and saves a checkpoint.
    """

    TRAIN_STAGE = 'train'
    VAL_STAGE = 'val'
    VERSION_FORMAT = 'VERSION_{}'
    MODEL_SAVE_FORMAT = 'version_{version:03}_model_{step:010}.pth'

    def __init__(self, batch_size: int = 32, num_epoch: int = 100,
                 train_split: float = 0.9, log_interval: int = 1000,
                 log_audio_factor: int = 5, lr: float = 0.001,
                 num_data: int = None, log_root: str = './tb_logs',
                 save_root: str = './checkpoints', num_workers: int = 4,
                 version: int = None, num_test_samples: int = 5):
        """
        Initialize tacotron trainer

        Args:
            batch_size: batch size
            num_epoch: total number of epochs to train
            train_split: train ratio of train-val split
            log_interval: interval, in training steps, for logging the
                training loss and test samples to tensorboard and for saving
                a checkpoint
            log_audio_factor: number of log_intervals between audio logs,
                since audio logging requires quite a lot of overhead
            lr: learning rate of the Adam optimizer
            num_data: number of datapoints to load in the dataset
            log_root: root directory for the tensorboard logging
            save_root: root directory for saving model
            num_workers: number of workers for dataloader
            version: version of training. If None, one past the highest
                existing ``VERSION_<n>`` directory under ``log_root``. An
                existing log directory for the chosen version is deleted.
            num_test_samples: number of test samples to generate for each
                logging (capped at the size of the validation split)
        """
        import shutil  # local import: only needed to clear a stale log dir

        os.makedirs(log_root, exist_ok=True)
        os.makedirs(save_root, exist_ok=True)

        is_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if is_cuda else 'cpu')
        self.train_split = train_split
        self.epoch_num = num_epoch
        self.splitted_dataset = self.__split_dataset(
            TorchLJSpeechDataset(num_data=num_data))
        self.dataloaders = self.__get_dataloaders(
            batch_size, num_workers=num_workers)

        self.tacotron = Tacotron()
        self.tacotron.to(self.device)
        self.loss = TacotronLoss()
        self.optimizer = Adam(self.tacotron.parameters(), lr=lr)
        self.lr_scheduler = StepLR(
            optimizer=self.optimizer, step_size=10000, gamma=0.9)

        # An explicit version must be honoured; otherwise self.version would
        # stay unset and checkpoint saving would fail later.
        if version is None:
            self.version = self._next_version(log_root)
        else:
            self.version = version
        log_dir = os.path.join(
            log_root, self.VERSION_FORMAT.format(self.version))
        # SummaryWriter log dirs are directories, which os.remove cannot delete.
        if os.path.isdir(log_dir):
            shutil.rmtree(log_dir)
        elif os.path.exists(log_dir):
            os.remove(log_dir)
        self.logger = SummaryWriter(log_dir)

        self.save_root = save_root
        self.log_interval = log_interval
        self.log_audio_factor = log_audio_factor
        self.global_step = 0
        self.running_count = {self.TRAIN_STAGE: 0, self.VAL_STAGE: 0}
        self.running_loss = {self.TRAIN_STAGE: 0, self.VAL_STAGE: 0}
        # Never ask for more samples than the validation split holds.
        num_val = len(self.splitted_dataset[self.VAL_STAGE])
        self.sample_indices = list(range(min(num_test_samples, num_val)))

    @classmethod
    def _next_version(cls, log_root: str) -> int:
        """Return one past the highest ``VERSION_<n>`` entry in ``log_root``.

        Entries that do not match the version format are ignored. Returns 0
        when no versioned entry exists.
        """
        prefix = cls.VERSION_FORMAT.format('')
        found = []
        for entry in os.listdir(log_root):
            if not entry.startswith(prefix):
                continue
            suffix = entry[len(prefix):]
            # Parse the whole numeric suffix (the last char alone breaks at 10).
            if suffix.isdecimal():
                found.append(int(suffix))
        return max(found) + 1 if found else 0

    def fit_from_checkpoint(self, checkpoint_file: str):
        """Load model weights from ``checkpoint_file`` and continue training."""
        self.tacotron.load(checkpoint_file, self.device)
        self.fit()

    def fit(self):
        """Train for ``self.epoch_num`` epochs."""
        for epoch in tqdm.tqdm(range(self.epoch_num), total=self.epoch_num,
                               desc='Epoch'):
            self.__run_epoch(epoch)

    def __run_epoch(self, epoch: int):
        """Run every stage once and log the per-stage mean loss."""
        # reset running loss and count after each epoch
        self.__reset_loss()
        self.__reset_count()

        for stage, dataloader in self.dataloaders.items():
            # Context manager closes the bar so bars do not pile up per epoch.
            with tqdm.tqdm(dataloader,
                           desc=f'{stage.capitalize()} in progress',
                           total=len(dataloader)) as prog_bar:
                for batch in dataloader:
                    self.__run_step(batch, stage, prog_bar)

        # epoch vs global step
        self.logger.add_scalar('epoch', epoch, global_step=self.global_step)

        # add loss to logger
        loss_dict = {stage: self.__calculate_mean_loss(stage)
                     for stage in self.running_loss}
        self.logger.add_scalars('loss', loss_dict, global_step=epoch)

    def __run_step(self, batch: TorchLJSpeechBatch, stage: str,
                   prog_bar: tqdm.tqdm):
        """Run one batch; on the train stage also optimize, log and checkpoint."""
        is_train = stage == self.TRAIN_STAGE
        if is_train:
            self.tacotron.train()
            self.optimizer.zero_grad()
        else:
            self.tacotron.eval()

        batch = batch.to(self.device)
        # Only training needs an autograd graph; skipping it for validation
        # saves memory and time.
        with torch.set_grad_enabled(is_train):
            output = self.tacotron.forward_train(batch)
            loss_val = self.loss(batch.mel_spec, output.pred_mel_spec,
                                 batch.lin_spec, output.pred_lin_spec)

        batch_size = batch.mel_spec.size(0)
        self.running_loss[stage] += loss_val.item() * batch_size
        self.running_count[stage] += batch_size

        if is_train:
            loss_val.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            if self.global_step % self.log_interval == 0:
                self.logger.add_scalar(
                    'training_loss', self.__calculate_mean_loss(stage),
                    global_step=self.global_step)

                # Audio logging is expensive, so it runs less often.
                log_audio = self.global_step % (
                    self.log_interval * self.log_audio_factor) == 0
                sample_results = self.__get_sample_results()
                for sample_result in sample_results:
                    self.__log_sample_results(
                        self.global_step, sample_result, log_audio=log_audio)
                # __get_sample_results switched the model to eval mode.
                self.tacotron.train()

                save_file = os.path.join(
                    self.save_root,
                    self.MODEL_SAVE_FORMAT.format(
                        version=self.version, step=self.global_step))
                torch.save(self.tacotron.state_dict(), save_file)

            self.global_step += 1

        prog_bar.update()
        prog_bar.set_postfix(
            {'Running Loss': f'{self.__calculate_mean_loss(stage):.3f}'})

    def __log_sample_results(self, steps: int, sample_result: SampleResult,
                             log_mel: bool = True, log_spec: bool = True,
                             log_attention: bool = True,
                             log_audio: bool = True) -> None:
        """
        Log the sample results into tensorboard

        Args:
            steps: current step
            sample_result: sample result to log
            log_mel: if True, log mel spectrogram
            log_spec: if True, log spectrogram
            log_attention: if True, log attention
            log_audio: if True, log audio
        """
        if log_mel:
            title = f'Log Mel Spectrogram, Step:{steps}, ' \
                    f'Uid: {sample_result.uid}'
            fig = self.__get_spec_plot(
                pred_spec=sample_result.pred_mel_spec,
                truth_spec=sample_result.truth_mel_spec,
                suptitle=title, ylabel='Mel')
            img_tensor = self.__get_plot_tensor(fig)
            tag = f'mel_spec/{sample_result.uid}'
            self.logger.add_image(tag, img_tensor, global_step=steps)

        if log_spec:
            title = f'Log Spectrogram, Step:{steps}, ' \
                    f'Uid: {sample_result.uid}'
            fig = self.__get_spec_plot(
                pred_spec=sample_result.pred_lin_spec,
                truth_spec=sample_result.truth_lin_spec,
                suptitle=title, ylabel='DFT bins')
            img_tensor = self.__get_plot_tensor(fig)
            tag = f'lin_spec/{sample_result.uid}'
            self.logger.add_image(tag, img_tensor, global_step=steps)

        if log_attention:
            # `steps` is a training step, not an epoch.
            title = f'Attention Weight, Step:{steps}, ' \
                    f'Uid: {sample_result.uid}'
            fig = self.__get_attention_plot(
                title=title, attention_weight=sample_result.attention_weight)
            img_tensor = self.__get_plot_tensor(fig)
            tag = f'attention/{sample_result.uid}'
            self.logger.add_image(tag, img_tensor, global_step=steps)

        if log_audio:
            pred_tag = f'audio/{sample_result.uid}_predicted'
            truth_tag = f'audio/{sample_result.uid}_truth'
            self.logger.add_audio(
                tag=pred_tag,
                snd_tensor=torch.from_numpy(
                    sample_result.pred_audio).unsqueeze(1),  # add channel dim
                global_step=steps,
                sample_rate=AudioProcessParam.sr
            )
            self.logger.add_audio(
                tag=truth_tag,
                snd_tensor=torch.from_numpy(
                    sample_result.truth_audio).unsqueeze(1),  # add channel dim
                global_step=steps,
                sample_rate=AudioProcessParam.sr
            )

    def __get_sample_results(self) -> List[SampleResult]:
        """
        Get sample results to show in tensorboard, including
        1. Predicted and ground truth spectrogram pairs
        2. Predicted and ground truth mel spectrogram pairs
        3. Predicted and ground truth audio pairs
        4. Attention weight

        Leaves the model in eval mode.

        Returns:
            list of sample results
        """
        val_dataset = self.splitted_dataset[self.VAL_STAGE]
        self.tacotron.eval()
        test_insts = []
        with torch.no_grad():
            for subset_i in self.sample_indices:
                datapoint: TorchLJSpeechData = val_dataset[subset_i]
                datapoint: TorchLJSpeechBatch = datapoint.add_batch_dim()
                datapoint = datapoint.to(self.device)

                # Map the Subset index back to the underlying dataset's uid.
                ds_idx = val_dataset.indices[subset_i]
                uid = val_dataset.dataset.uids[ds_idx]

                # Transcription
                transcription = val_dataset.dataset.uid_to_transcription[uid]

                wav_filepath = os.path.join(
                    val_dataset.dataset.wav_save_dir, f'{uid}.wav')
                truth_audio = AudioProcessingHelper.load_audio(wav_filepath)

                taco_output = self.tacotron.forward_train(datapoint)
                spec = taco_output.pred_lin_spec.squeeze(0).cpu().numpy().T
                pred_audio = AudioProcessingHelper.spec2audio(spec)

                test_insts.append(
                    SampleResult(
                        uid=uid,
                        transcription=transcription,
                        truth_lin_spec=datapoint.lin_spec.squeeze(0).cpu().numpy().T,
                        pred_lin_spec=taco_output.pred_lin_spec.squeeze(0).cpu().numpy().T,
                        truth_mel_spec=datapoint.mel_spec.squeeze(0).cpu().numpy().T,
                        pred_mel_spec=taco_output.pred_mel_spec.squeeze(0).cpu().numpy().T,
                        attention_weight=taco_output.attention_weight.squeeze(0).cpu().numpy(),
                        truth_audio=truth_audio,
                        pred_audio=pred_audio
                    )
                )
        return test_insts

    @staticmethod
    def __get_attention_plot(
            title: str, attention_weight: np.ndarray) -> plt.Figure:
        """
        Get figure handle for attention plot

        Args:
            title: title of the plot
            attention_weight: attention weight to plot

        Returns:
            figure object
        """
        fig = plt.figure(figsize=(6, 5), dpi=80)
        plt.title(title)
        plt.imshow(attention_weight, aspect='auto')
        plt.colorbar()
        plt.xlabel('Encoder seq')
        plt.ylabel('Decoder seq')
        # Let the x, y axis start from the left-bottom corner
        plt.gca().invert_yaxis()
        plt.close(fig)
        return fig

    @staticmethod
    def __get_spec_plot(pred_spec: np.ndarray, truth_spec: np.ndarray,
                        suptitle: str, ylabel: str) -> plt.Figure:
        """
        Get a juxtaposition of two spectrograms with a shared color scale

        Args:
            pred_spec: predicted spectrogram
            truth_spec: ground truth spectrogram
            suptitle: title of the plot
            ylabel: unit of frequency axis of the spectrograms

        Returns:
            figure object
        """
        # Shared vmin/vmax so both panels use the same color mapping.
        vmin = min(np.min(truth_spec), np.min(pred_spec))
        vmax = max(np.max(truth_spec), np.max(pred_spec))

        fig = plt.figure(figsize=(11, 5), dpi=80)
        plt.suptitle(suptitle)

        ax1 = plt.subplot(121)
        plt.title('Ground Truth')
        plt.xlabel('Frame')
        plt.ylabel(ylabel)
        plt.imshow(truth_spec, vmin=vmin, vmax=vmax, aspect='auto')
        # let the x, y axis start from the left-bottom corner
        plt.gca().invert_yaxis()

        ax2 = plt.subplot(122)
        plt.title('Predicted')
        plt.xlabel('Frame')
        im = plt.imshow(pred_spec, vmin=vmin, vmax=vmax, aspect='auto')
        # let the x, y axis start from the left-bottom corner
        plt.gca().invert_yaxis()

        fig.tight_layout()
        fig.colorbar(im, ax=[ax1, ax2])
        plt.close(fig)
        return fig

    @staticmethod
    def __get_plot_tensor(fig) -> torch.Tensor:
        """
        Get tensor for the given figure object

        Args:
            fig: the figure object to convert into tensor

        Returns:
            tensor of the figure
        """
        with io.BytesIO() as buf:
            fig.savefig(buf, format='jpeg')
            buf.seek(0)
            image = PIL.Image.open(buf)
            # ToTensor reads the pixel data while the buffer is still open.
            return ToTensor()(image)

    def __calculate_mean_loss(self, stage: str) -> float:
        """
        Calculate mean loss for given stage (train/val)

        Args:
            stage: train/val

        Returns:
            mean loss, or NaN when no sample has been seen for the stage
        """
        count = self.running_count[stage]
        if count == 0:
            return float('nan')
        return self.running_loss[stage] / count

    def __reset_loss(self) -> None:
        self.running_loss = {self.TRAIN_STAGE: 0, self.VAL_STAGE: 0}

    def __reset_count(self) -> None:
        self.running_count = {self.TRAIN_STAGE: 0, self.VAL_STAGE: 0}

    def __split_dataset(self, dataset: TorchLJSpeechDataset
                        ) -> Dict[str, Subset]:
        """
        Split the dataset into train/validation set

        Args:
            dataset: dataset to split

        Returns:
            splitted dataset
        """
        num_train_data = int(len(dataset) * self.train_split)
        num_val_data = len(dataset) - num_train_data
        train_dataset, val_dataset = random_split(
            dataset, [num_train_data, num_val_data])
        return {self.TRAIN_STAGE: train_dataset, self.VAL_STAGE: val_dataset}

    def __get_dataloaders(
            self, batch_size: int, num_workers: int) -> Dict[str, DataLoader]:
        """Build one DataLoader per split; only the train split is shuffled."""
        return {
            stage: DataLoader(
                dataset,
                shuffle=(stage == self.TRAIN_STAGE),
                collate_fn=TorchLJSpeechDataset.batch_tacotron_input,
                pin_memory=True,
                batch_size=batch_size,
                num_workers=num_workers)
            for stage, dataset in self.splitted_dataset.items()
        }
class Logger:
    """Combined text logger (console and/or file) and TensorBoard writer.

    Text logging uses a dedicated ``logging.Logger`` per instance. The
    TensorBoard writer (and the log file) are created only when both
    ``log_dir`` and ``phase`` are given; the ``add_*`` wrappers require it.
    """

    # Class-wide counter so each instance gets its own logging.Logger name.
    _count = 0

    def __init__(self, scrn=True, log_dir='', phase=''):
        """
        :param scrn: if True, echo INFO-and-above messages to a stream handler
        :param log_dir: directory for the ``.log`` file and TensorBoard run dir
        :param phase: name prefix for the log file and the TensorBoard run dir
        """
        super().__init__()

        self._logger = logging.getLogger('logger_{}'.format(Logger._count))
        Logger._count += 1
        self._logger.setLevel(logging.DEBUG)

        if scrn:
            self._scrn_handler = logging.StreamHandler()
            self._scrn_handler.setLevel(logging.INFO)
            self._scrn_handler.setFormatter(
                logging.Formatter(fmt=FORMAT_SHORT))
            self._logger.addHandler(self._scrn_handler)

        if log_dir and phase:
            # Build the timestamped stem once so the .log file and the
            # TensorBoard run directory always share the same name.
            stem = '{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}'.format(
                phase, *localtime()[:6])
            self.log_path = os.path.join(log_dir, stem + '.log')
            self.show_nl("log into {}\n\n".format(self.log_path))
            self._file_handler = logging.FileHandler(filename=self.log_path)
            self._file_handler.setLevel(logging.DEBUG)
            self._file_handler.setFormatter(
                logging.Formatter(fmt=FORMAT_LONG))
            self._logger.addHandler(self._file_handler)

            self._writer = SummaryWriter(log_dir=os.path.join(log_dir, stem))

    def show(self, *args, **kwargs):
        return self._logger.info(*args, **kwargs)

    def show_nl(self, *args, **kwargs):
        """Log an empty INFO line, then the message."""
        self._logger.info("")
        return self.show(*args, **kwargs)

    def dump(self, *args, **kwargs):
        return self._logger.debug(*args, **kwargs)

    def warning(self, *args, **kwargs):
        return self._logger.warning(*args, **kwargs)

    def error(self, *args, **kwargs):
        return self._logger.error(*args, **kwargs)

    # tensorboard
    def add_scalar(self, *args, **kwargs):
        return self._writer.add_scalar(*args, **kwargs)

    def add_scalars(self, *args, **kwargs):
        return self._writer.add_scalars(*args, **kwargs)

    def add_histogram(self, *args, **kwargs):
        return self._writer.add_histogram(*args, **kwargs)

    def add_image(self, *args, **kwargs):
        return self._writer.add_image(*args, **kwargs)

    def add_images(self, *args, **kwargs):
        return self._writer.add_images(*args, **kwargs)

    def add_figure(self, *args, **kwargs):
        return self._writer.add_figure(*args, **kwargs)

    def add_video(self, *args, **kwargs):
        return self._writer.add_video(*args, **kwargs)

    def add_audio(self, *args, **kwargs):
        return self._writer.add_audio(*args, **kwargs)

    def add_text(self, *args, **kwargs):
        return self._writer.add_text(*args, **kwargs)

    def add_graph(self, *args, **kwargs):
        return self._writer.add_graph(*args, **kwargs)

    def add_pr_curve(self, *args, **kwargs):
        return self._writer.add_pr_curve(*args, **kwargs)

    def add_custom_scalars(self, *args, **kwargs):
        return self._writer.add_custom_scalars(*args, **kwargs)

    def add_mesh(self, *args, **kwargs):
        return self._writer.add_mesh(*args, **kwargs)

    # def add_hparams(self, *args, **kwargs):
    #     return self._writer.add_hparams(*args, **kwargs)

    def flush(self):
        return self._writer.flush()

    def close(self):
        return self._writer.close()

    def _grad_hook(self, grad, name=None, grads=None):
        """Backward hook: record the latest gradient under ``name``."""
        grads.update({name: grad})

    def watch_grad(self, model, layers):
        """
        Add hooks to the specific layers. Gradients of these layers will be
        saved to self.grads (keyed by parameter name).
        :param model: module whose ``named_parameters()`` are indexed
        :param layers: Expects a list of indices into
            ``list(model.named_parameters())``, e.g. layers=[0, -1] watches
            the gradients of the first and the last parameter of the model
        :return: None
        """
        assert layers
        if not hasattr(self, 'grads'):
            self.grads = {}
            self.grad_hooks = {}

        named_parameters = list(model.named_parameters())
        for layer in layers:
            name, param = named_parameters[layer]
            # Drop an existing hook for the same parameter so it is not
            # orphaned (and left firing) when its handle is replaced.
            previous = self.grad_hooks.pop(name, None)
            if previous is not None:
                previous.remove()
            handle = param.register_hook(
                functools.partial(self._grad_hook, name=name, grads=self.grads))
            # Key by the parameter name so every handle is kept and can be
            # removed later.
            self.grad_hooks[name] = handle

    def watch_grad_close(self):
        """Remove every hook registered by ``watch_grad`` (no-op if none)."""
        hooks = getattr(self, 'grad_hooks', {})
        for handle in hooks.values():
            handle.remove()  # remove the hook
        hooks.clear()

    def add_grads(self, global_step=None, *args, **kwargs):
        """
        Add gradients to tensorboard. You must call the method
        self.watch_grad before using this method!
        """
        assert hasattr(self, 'grads'),\
            "self.grads is nonexisent! You must call self.watch_grad before!"
        assert self.grads, "self.grads if empty!"
        for (name, grad) in self.grads.items():
            self.add_histogram(tag=name, values=grad, global_step=global_step,
                               *args, **kwargs)

    @staticmethod
    def make_desc(counter, total, *triples):
        """Build a progress string like ``[3/10] loss 0.50 (0.25)``.

        Each triple is (name to display, object with ``val`` and ``avg``
        attributes such as an AverageMeter, format spec string).
        """
        desc = "[{}/{}]".format(counter, total)
        for name, obj, fmt in triples:
            desc += (" {} {obj.val:" + fmt + "} ({obj.avg:" + fmt + "})").format(
                name, obj=obj)
        return desc
def main():
    """Train a GAN vocoder (generator + discriminator), or run inference.

    Options come from ``parse_args()``. When ``args.inference`` is truthy,
    the generator is run once on the tensor stored in ``unseen.pt`` and the
    result is saved and logged. Otherwise the discriminator is trained with a
    hinge-style loss and the generator with an adversarial plus
    feature-matching loss. Losses and audio samples are logged to TensorBoard
    under ``args.save_path``, and checkpoints are written every
    ``args.save_interval`` steps.
    """
    args = parse_args()

    root = Path(args.save_path)
    spec_type = args.spectrogram
    inference = args.inference
    load_root = Path(args.load_path) if args.load_path else None
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    writer = SummaryWriter(str(root))

    try:
        #######################
        # Load PyTorch Models #
        #######################
        # `elif` is essential: with two independent `if`s, choosing "mel"
        # fell through to the `else` branch and was replaced by CQT models.
        if spec_type == "mel":
            print('Mel spec selected')
            # generator takes n_mel_channels input channels
            netG = Generator(args.n_mel_channels, args.ngf,
                             args.n_residual_layers).cuda()
            fft = Audio2Mel(n_mel_channels=args.n_mel_channels).cuda()
        elif spec_type == "cqt":
            print('CQT spec selected')
            # generator takes n_bins input channels
            netG = Generator(args.n_bins, args.ngf,
                             args.n_residual_layers).cuda()
            fft = Audio2Cqt(n_bins=args.n_bins).cuda()
        else:
            print('No spectrogram specified, defaulting to CQT')
            netG = Generator(args.n_bins, args.ngf,
                             args.n_residual_layers).cuda()
            fft = Audio2Cqt(n_bins=args.n_bins).cuda()

        netD = Discriminator(args.num_D, args.ndf, args.n_layers_D,
                             args.downsamp_factor).cuda()

        print(netG)
        print(netD)

        #####################
        # Create optimizers #
        #####################
        optG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
        optD = torch.optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))

        if load_root and load_root.exists():
            print('Loading weights')
            netG.load_state_dict(torch.load(load_root / "netG.pt"))
            optG.load_state_dict(torch.load(load_root / "optG.pt"))
            netD.load_state_dict(torch.load(load_root / "netD.pt"))
            optD.load_state_dict(torch.load(load_root / "optD.pt"))
            print('Weights loaded')

        #############
        # Inference #
        #############
        if inference:
            print('Starting inference')
            with torch.no_grad():
                x = torch.load('unseen.pt').cuda()
                pred_audio = netG(x).squeeze().cpu()
                save_sample(root / ("generated_sample.wav"), 22050, pred_audio)
                writer.add_audio(
                    "generated/sample_test.wav",
                    pred_audio,
                    global_step=None,
                    sample_rate=22050,
                )
            print('Finished inference')
            return

        #######################
        # Create data loaders #
        #######################
        train_set = AudioDataset(Path(args.data_path) / "train_files.txt",
                                 args.seq_len, sampling_rate=22050)
        test_set = AudioDataset(Path(args.data_path) / "test_files.txt",
                                22050 * 4, sampling_rate=22050, augment=False)
        train_loader = DataLoader(train_set, batch_size=args.batch_size,
                                  num_workers=4)
        test_loader = DataLoader(test_set, batch_size=1)

        ##########################
        # Dumping original audio #
        ##########################
        test_voc = []
        test_audio = []
        for i, x_t in enumerate(test_loader):
            x_t = x_t.cuda()
            s_t = fft(x_t).detach()

            test_voc.append(s_t.cuda())
            test_audio.append(x_t)

            audio = x_t.squeeze().cpu()
            save_sample(root / ("original_%d.wav" % i), 22050, audio)
            writer.add_audio("original/sample_%d.wav" % i, audio, 0,
                             sample_rate=22050)

            if i == args.n_test_samples - 1:
                break

        costs = []
        start = time.time()

        # enable cudnn autotuner to speed up training
        torch.backends.cudnn.benchmark = True

        best_mel_reconst = 1000000
        steps = 0
        for epoch in range(1, args.epochs + 1):
            for iterno, x_t in enumerate(train_loader):
                x_t = x_t.cuda()
                s_t = fft(x_t).detach()  # spectrogram of the real audio
                x_pred_t = netG(s_t.cuda())  # audio generated from it

                with torch.no_grad():
                    # Spectrogram of the generated audio, compared with the
                    # real one for monitoring only (not backpropagated).
                    s_pred_t = fft(x_pred_t.detach())
                    s_error = F.l1_loss(s_t, s_pred_t).item()

                #######################
                # Train Discriminator #
                #######################
                D_fake_det = netD(x_pred_t.cuda().detach())
                D_real = netD(x_t.cuda())

                loss_D = 0
                for scale in D_fake_det:
                    loss_D += F.relu(1 + scale[-1]).mean()
                for scale in D_real:
                    loss_D += F.relu(1 - scale[-1]).mean()

                netD.zero_grad()
                loss_D.backward()
                optD.step()

                ###################
                # Train Generator #
                ###################
                D_fake = netD(x_pred_t.cuda())

                loss_G = 0
                for scale in D_fake:
                    loss_G += -scale[-1].mean()

                # Feature matching over every intermediate discriminator output
                loss_feat = 0
                feat_weights = 4.0 / (args.n_layers_D + 1)
                D_weights = 1.0 / args.num_D
                wt = D_weights * feat_weights
                for i in range(args.num_D):
                    for j in range(len(D_fake[i]) - 1):
                        loss_feat += wt * F.l1_loss(D_fake[i][j],
                                                    D_real[i][j].detach())

                netG.zero_grad()
                (loss_G + args.lambda_feat * loss_feat).backward()
                optG.step()

                ######################
                # Update tensorboard #
                ######################
                costs.append(
                    [loss_D.item(), loss_G.item(), loss_feat.item(), s_error])

                writer.add_scalar("loss/discriminator", costs[-1][0], steps)
                writer.add_scalar("loss/generator", costs[-1][1], steps)
                writer.add_scalar("loss/feature_matching", costs[-1][2], steps)
                writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps)
                steps += 1

                if steps % args.save_interval == 0:
                    st = time.time()
                    with torch.no_grad():
                        for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                            pred_audio = netG(voc)
                            pred_audio = pred_audio.squeeze().cpu()
                            save_sample(root / ("generated_%d.wav" % i), 22050,
                                        pred_audio)
                            writer.add_audio(
                                "generated/sample_%d.wav" % i,
                                pred_audio,
                                epoch,
                                sample_rate=22050,
                            )

                    torch.save(netG.state_dict(), root / "netG.pt")
                    torch.save(optG.state_dict(), root / "optG.pt")
                    torch.save(netD.state_dict(), root / "netD.pt")
                    torch.save(optD.state_dict(), root / "optD.pt")

                    # Keep the weights with the lowest mean spectrogram error
                    # over the costs accumulated since the last log.
                    mean_reconst = np.asarray(costs).mean(0)[-1]
                    if mean_reconst < best_mel_reconst:
                        best_mel_reconst = mean_reconst
                        torch.save(netD.state_dict(), root / "best_netD.pt")
                        torch.save(netG.state_dict(), root / "best_netG.pt")

                    print("Took %5.4fs to generate samples" % (time.time() - st))
                    print("-" * 100)

                if steps % args.log_interval == 0:
                    print("Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".format(
                        epoch,
                        iterno,
                        len(train_loader),
                        1000 * (time.time() - start) / args.log_interval,
                        np.asarray(costs).mean(0),
                    ))
                    costs = []
                    start = time.time()
    finally:
        # Flush pending TensorBoard events on normal exit, early return
        # (inference) or error.
        writer.close()
def train_fn(args, params):
    """Train ``Vocoder`` with next-sample cross-entropy and TensorBoard logging.

    Args:
        args: command-line namespace. Uses results_dir, exp_name, optim
            ("no" disables apex mixed precision, otherwise it is the apex
            opt_level), nc, resume (checkpoint path or None), data_dir and
            save_step.
        params: nested config dict with "preprocessing" and "vocoder"
            sections.

    The model predicts ``audio[:, 1:]`` from ``audio[:, :-1]`` and the mel
    frames. Checkpoints are written every ``args.save_step`` and every
    ``params["vocoder"]["checkpoint_interval"]`` steps. Generated samples are
    written every ``params["vocoder"]["generation_interval"]`` steps.
    """
    # Directory preparation
    exp_dir = makeExpDirs(args.results_dir, args.exp_name)

    # Automatic Mixed-Precision
    if args.optim != "no":
        import apex

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Vocoder(mel_channels=params["preprocessing"]["num_mels"],
                    conditioning_channels=params["vocoder"]["conditioning_channels"],
                    embedding_dim=params["vocoder"]["embedding_dim"],
                    rnn_channels=params["vocoder"]["rnn_channels"],
                    fc_channels=params["vocoder"]["fc_channels"],
                    bits=params["preprocessing"]["bits"],
                    hop_length=params["preprocessing"]["hop_length"],
                    nc=args.nc,
                    device=device
                    )
    model.to(device)
    print(model)

    optimizer = optim.Adam(model.parameters(),
                           lr=params["vocoder"]["learning_rate"])

    # Automatic Mixed-Precision
    if args.optim != "no":
        model, optimizer = apex.amp.initialize(model, optimizer,
                                               opt_level=args.optim)

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          params["vocoder"]["schedule"]["step_size"],
                                          params["vocoder"]["schedule"]["gamma"])

    if args.resume is not None:
        print(f"Resume checkpoint from: {args.resume}:")
        # Load onto CPU first; load_state_dict moves tensors as needed.
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        global_step = checkpoint["step"]
    else:
        global_step = 0

    train_dataset = VocoderDataset(meta_file=os.path.join(args.data_dir, "train.txt"),
                                   sample_frames=params["vocoder"]["sample_frames"],
                                   audio_slice_frames=params["vocoder"]["audio_slice_frames"],
                                   hop_length=params["preprocessing"]["hop_length"],
                                   bits=params["preprocessing"]["bits"])

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=params["vocoder"]["batch_size"],
                                  shuffle=True, num_workers=1,
                                  pin_memory=True)

    num_epochs = params["vocoder"]["num_steps"] // len(train_dataloader) + 1
    start_epoch = global_step // len(train_dataloader) + 1

    # Logger
    writer = SummaryWriter(exp_dir/"logs")
    try:
        # Add the first original test utterance to TensorBoard
        if args.resume is None:
            with open(os.path.join(args.data_dir, "test.txt"), encoding="utf-8") as f:
                test_wavnpy_paths = [line.strip().split("|")[1] for line in f]
            if test_wavnpy_paths:
                mulaw_npy = np.load(test_wavnpy_paths[0])
                wav_npy = mulaw_decode(mulaw_npy,
                                       2**params["preprocessing"]["bits"])
                writer.add_audio("orig", torch.from_numpy(wav_npy),
                                 global_step=global_step,
                                 sample_rate=params["preprocessing"]["sample_rate"])

        for epoch in range(start_epoch, num_epochs + 1):
            running_loss = 0

            for i, (audio, mels) in enumerate(tqdm(train_dataloader, leave=False), 1):
                audio, mels = audio.to(device), mels.to(device)

                # Predict each sample from the preceding ones (teacher forcing).
                output = model(audio[:, :-1], mels)
                loss = F.cross_entropy(output.transpose(1, 2), audio[:, 1:])
                optimizer.zero_grad()

                # Automatic Mixed-Precision
                if args.optim != "no":
                    with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                optimizer.step()
                scheduler.step()

                running_loss += loss.item()
                average_loss = running_loss / i

                global_step += 1

                if global_step % args.save_step == 0:
                    save_checkpoint(model, optimizer, scheduler, global_step,
                                    exp_dir/"params", False)

                if global_step % params["vocoder"]["checkpoint_interval"] == 0:
                    save_checkpoint(model, optimizer, scheduler, global_step,
                                    exp_dir/"params", True)

                if global_step % params["vocoder"]["generation_interval"] == 0:
                    with open(os.path.join(args.data_dir, "test.txt"), encoding="utf-8") as f:
                        test_mel_paths = [line.strip().split("|")[2] for line in f]

                    for index, mel_path in enumerate(test_mel_paths):
                        utterance_id = os.path.basename(mel_path).split(".")[0]
                        # unsqueeze: insert in a batch
                        mel = torch.FloatTensor(np.load(mel_path)).unsqueeze(0).to(device)
                        output = model.generate(mel)
                        path = exp_dir/"samples"/f"gen_{utterance_id}_model_steps_{global_step}.wav"
                        save_wav(str(path), output,
                                 params["preprocessing"]["sample_rate"])
                        # Only the first generated utterance goes to TensorBoard.
                        if index == 0:
                            writer.add_audio("cnvt", torch.from_numpy(output),
                                             global_step=global_step,
                                             sample_rate=params["preprocessing"]["sample_rate"])

            # finish an epoch
            writer.add_scalar("NLL", average_loss, global_step)
    finally:
        # Ensure buffered TensorBoard events are written and the file closed.
        writer.close()
class Logger(object):
    """Rank-aware TensorBoard logger and checkpoint helper.

    Only rank 0 creates the log directory and SummaryWriter and writes
    anything; the logging/saving methods are no-ops on other ranks.
    """

    def __init__(self, config, rank=0):
        """
        Args:
            config: object exposing ``training_config.logdir``,
                ``training_config.continue_training``,
                ``data_config.sample_rate`` and ``to_dict_type()``.
            rank: process rank; only rank 0 writes logs and checkpoints.

        Raises:
            RuntimeError: on rank 0, when training from scratch but the
                logdir already exists.
        """
        self.rank = rank
        self.summary_writer = None
        # Read on every rank: load_latest_checkpoint needs these even on
        # ranks that never create a writer.
        self.logdir = config.training_config.logdir
        self.continue_training = config.training_config.continue_training
        self.sample_rate = config.data_config.sample_rate
        if self.rank != 0:
            return

        if not self.continue_training and os.path.exists(self.logdir):
            raise RuntimeError(
                f"You're trying to run training from scratch, "
                f"but logdir `{self.logdir} already exists. Remove it or specify new one.`"
            )
        if not self.continue_training:
            os.makedirs(self.logdir)
        self.summary_writer = SummaryWriter(config.training_config.logdir)
        self.save_model_config(config)

    def _log_losses(self, iteration, loss_stats: dict):
        """Write every entry of ``loss_stats`` as a scalar at ``iteration``."""
        for key, value in loss_stats.items():
            self.summary_writer.add_scalar(key, value, iteration)

    def log_training(self, iteration, stats, verbose=True):
        """Log training stats under the ``training/`` prefix (rank 0 only)."""
        if self.rank != 0:
            return
        stats = {f'training/{key}': value for key, value in stats.items()}
        self._log_losses(iteration, loss_stats=stats)
        show_message(
            f'Iteration: {iteration} | Losses: {[value for value in stats.values()]}',
            verbose=verbose)

    def log_test(self, epoch, stats, verbose=True):
        """Log test stats under the ``test/`` prefix (rank 0 only)."""
        if self.rank != 0:
            return
        stats = {f'test/{key}': value for key, value in stats.items()}
        self._log_losses(epoch, loss_stats=stats)
        show_message(
            f'Epoch: {epoch} | Losses: {[value for value in stats.values()]}',
            verbose=verbose)

    def log_audios(self, iteration, audios: dict):
        """Write each audio clip at the configured sample rate (rank 0 only)."""
        if self.rank != 0:
            return
        for key, audio in audios.items():
            self.summary_writer.add_audio(key, audio, iteration,
                                          sample_rate=self.sample_rate)

    def log_specs(self, iteration, specs: dict):
        """Render each spectrogram with plot_tensor_to_numpy and log it as an HWC image."""
        if self.rank != 0:
            return
        for key, image in specs.items():
            self.summary_writer.add_image(key, plot_tensor_to_numpy(image),
                                          iteration, dataformats='HWC')

    def save_model_config(self, config):
        """Dump ``config.to_dict_type()`` to ``<logdir>/config.json`` (rank 0 only)."""
        if self.rank != 0:
            return
        with open(f'{self.logdir}/config.json', 'w') as f:
            json.dump(config.to_dict_type(), f)

    def save_checkpoint(self, iteration, model, optimizer=None):
        """Save model (and optimizer) state to the writer's log dir (rank 0 only)."""
        if self.rank != 0:
            return
        d = {}
        d['iteration'] = iteration
        d['model'] = model.state_dict()
        if optimizer is not None:
            d['optimizer'] = optimizer.state_dict()
        filename = f'{self.summary_writer.log_dir}/checkpoint_{iteration}.pt'
        torch.save(d, filename)

    def load_latest_checkpoint(self, model, optimizer=None):
        """Load the newest checkpoint from logdir via the module-level helper.

        Raises:
            RuntimeError: if ``continue_training`` is not enabled.
        """
        if not self.continue_training:
            raise RuntimeError(
                f"Trying to load the latest checkpoint from logdir {self.logdir}, "
                "but did not set `continue_training=true` in configuration.")
        model, optimizer, iteration = load_latest_checkpoint(
            self.logdir, model, optimizer)
        return model, optimizer, iteration
frame_sr=dataset.frame_sr, dtype=dtype, ) training = Training(model, dataset, batch_size=batch_size) # plot ref audio and waveform ref_audio, ref_f0, ref_f0_scaled, ref_lo = dataset.get_sample(0) ref_audio = ref_audio.numpy() # max 5 secs ref_audio = ref_audio[:dataset.audio_sr * 5] ref_f0 = ref_f0[:dataset.frame_sr * 5] ref_f0_scaled = ref_f0_scaled[:dataset.frame_sr * 5] ref_lo = ref_lo[:dataset.frame_sr * 5] writer.add_audio("Groundtruth", ref_audio, 0, dataset.audio_sr) fig = plot_wf(ref_audio, dataset.audio_sr) writer.add_figure("Waveform/Groundtruth", fig, 0, True) fig = plot_spectrogram(ref_audio, dataset.audio_sr) writer.add_figure("Spectrogram/Groundtruth", fig, 0, True) writer.flush() best_epoch_loss = torch.finfo(dtype).max while training.epoch < nb_epochs: # run training and test epoch training.train_epoch() training.test_epoch() # plot losses writer.add_scalar("Loss/Train", training.train_loss, training.epoch)
class TrainingProcessHandler(object):
    """Drive training bookkeeping through callbacks.

    Handles TensorBoard scalars/images/figures/audio/text, tqdm progress
    bars, best-model checkpointing (lower ``model_save_key`` metric is
    better) and optional MLflow reporting.
    """

    def __init__(self, data_folder="logs", model_folder="model",
                 enable_iteration_progress_bar=False, model_save_key="loss",
                 mlflow_tags=None, mlflow_parameters=None, enable_mlflow=True,
                 mlflow_experiment_name="DDT"):
        """
        Args:
            data_folder: root folder for TensorBoard runs and figure artifacts
            model_folder: folder where the best checkpoint is saved
            enable_iteration_progress_bar: show a per-epoch iteration bar
            model_save_key: metric compared across epochs to decide saving
            mlflow_tags: tags forwarded to the MLflow handler
            mlflow_parameters: parameters forwarded to the MLflow handler
            enable_mlflow: create an MlFlowHandler when True
            mlflow_experiment_name: MLflow experiment name
        """
        if mlflow_tags is None:
            mlflow_tags = {}
        if mlflow_parameters is None:
            mlflow_parameters = {}
        self._name = None
        self._epoch_count = 0
        self._iteration_count = 0
        self._current_epoch = 0
        self._current_iteration = 0
        self._log_folder = data_folder
        self._iteration_progress_bar = None
        self._enable_iteration_progress_bar = enable_iteration_progress_bar
        self._epoch_progress_bar = None
        self._writer = None
        self._train_metrics = {}
        self._model = None
        self._model_folder = model_folder
        self._run_name = ""
        self.train_history = {}
        self.validation_history = {}
        self._model_save_key = model_save_key
        self._previous_model_save_metric = None
        # makedirs also creates missing parent folders (os.mkdir did not).
        os.makedirs(self._model_folder, exist_ok=True)
        os.makedirs(self._log_folder, exist_ok=True)
        self._audio_configs = {}
        self._global_epoch_step = 0
        self._global_iteration_step = 0
        if enable_mlflow:
            self._mlflow_handler = MlFlowHandler(
                experiment_name=mlflow_experiment_name,
                mlflow_tags=mlflow_tags,
                mlflow_parameters=mlflow_parameters)
        else:
            self._mlflow_handler = None
        self._artifacts = []

    def setup_handler(self, name, model):
        """Start a new timestamped run for ``model`` and reset its history."""
        self._name = name
        self._run_name = name + "_" + datetime.datetime.now().strftime(
            '%Y-%m-%d-%H-%M-%S')
        self._writer = SummaryWriter(
            os.path.join(self._log_folder, self._run_name))
        self._model = model
        self._previous_model_save_metric = None
        self.train_history = {}
        self.validation_history = {}
        self._global_epoch_step = 0
        self._global_iteration_step = 0

    def start_callback(self, epoch_count, iteration_count, parameters=None):
        """Create progress bars and notify MLflow that training starts.

        ``iteration_count`` is the total over all epochs; the iteration bar
        is sized per epoch.
        """
        if parameters is None:
            parameters = {}
        self._epoch_count = epoch_count
        self._iteration_count = iteration_count
        self._current_epoch = 0
        self._current_iteration = 0
        self._epoch_progress_bar = tqdm(total=self._epoch_count)
        if self._enable_iteration_progress_bar:
            self._iteration_progress_bar = tqdm(
                total=self._iteration_count // self._epoch_count)
        if self._mlflow_handler is not None:
            self._mlflow_handler.start_callback(parameters)

    def epoch_callback(self, metrics, image_batches=None, figures=None,
                       audios=None, texts=None):
        """Record validation metrics/media for the epoch and maybe checkpoint."""
        self._artifacts = []
        for key, value in metrics.items():
            self.validation_history.setdefault(key, []).append(value)
        self._write_epoch_metrics(metrics)
        if image_batches is not None:
            self._write_image_batches(image_batches)
        if figures is not None:
            self._write_figures(figures)
        if audios is not None:
            self._write_audios(audios)
        if texts is not None:
            self._write_texts(texts)
        # Reset the iteration bar for the next epoch, except after the final
        # epoch (the old comparison was always true).
        if (self._enable_iteration_progress_bar
                and self._current_epoch != self._epoch_count - 1):
            self._iteration_progress_bar.reset()
        self._epoch_progress_bar.update()
        self._epoch_progress_bar.set_postfix_str(
            self.metric_string("valid", metrics))
        if self.should_save_model(metrics) and self._model is not None:
            torch.save(
                self._model.state_dict(),
                os.path.join(self._model_folder,
                             f"{self._run_name}_checkpoint.pth"))
        self._current_epoch += 1
        self._global_epoch_step += 1
        if self._mlflow_handler is not None:
            self._mlflow_handler.epoch_callback(metrics, self._current_epoch,
                                                self._artifacts)

    def iteration_callback(self, metrics):
        """Record per-iteration training metrics."""
        for key, value in metrics.items():
            self.train_history.setdefault(key, []).append(value)
        self._train_metrics = metrics
        self._write_iteration_metrics(metrics)
        if self._enable_iteration_progress_bar:
            self._iteration_progress_bar.set_postfix_str(
                self.metric_string("train", metrics))
            self._iteration_progress_bar.update()
        self._current_iteration += 1
        self._global_iteration_step += 1

    def finish_callback(self, metrics):
        """Print final metrics and close writer, bars and MLflow run."""
        print(self.metric_string("test", metrics))
        self._writer.close()
        if self._enable_iteration_progress_bar:
            self._iteration_progress_bar.close()
        self._epoch_progress_bar.close()
        if self._mlflow_handler is not None:
            self._mlflow_handler.finish_callback()

    @staticmethod
    def metric_string(prefix, metrics):
        """Format metrics as ``"<prefix> <key> = <value:.3f>, ..."``."""
        result = ""
        for key, value in metrics.items():
            result += "{} {} = {:>3.3f}, ".format(prefix, key, value)
        return result[:-2]

    def _write_epoch_metrics(self, validation_metrics):
        for key, value in validation_metrics.items():
            self._writer.add_scalar(f"epoch/{key}", value,
                                    global_step=self._global_epoch_step)

    def _write_iteration_metrics(self, train_metrics):
        for key, value in train_metrics.items():
            self._writer.add_scalar(f"iteration/{key}", value,
                                    global_step=self._global_iteration_step)

    def should_save_model(self, metrics):
        """Return True when the tracked metric is missing, first seen or decreased."""
        if self._model_save_key not in metrics.keys():
            return True
        if self._previous_model_save_metric is None:
            self._previous_model_save_metric = metrics[self._model_save_key]
            return True
        if self._previous_model_save_metric > metrics[self._model_save_key]:
            self._previous_model_save_metric = metrics[self._model_save_key]
            return True
        return False

    def _write_image_batches(self, image_batches):
        for key, value in image_batches.items():
            self._writer.add_images(key, value, self._global_epoch_step,
                                    dataformats="NHWC")

    def _write_figures(self, figures):
        for key, value in figures.items():
            self._writer.add_figure(key, value, self._global_epoch_step)
            artifact_name = f"{self._log_folder}/{key}_{self._global_epoch_step:04d}.png"
            # Keys may contain "/" (TensorBoard grouping); make sure the
            # matching sub-directory exists before saving the artifact.
            os.makedirs(os.path.dirname(artifact_name), exist_ok=True)
            value.savefig(artifact_name)
            self._artifacts.append(artifact_name)

    def _write_audios(self, audios):
        for key, value in audios.items():
            self._writer.add_audio(key, value, self._global_epoch_step,
                                   **self._audio_configs)

    def set_audio_configs(self, configs):
        """Set extra keyword arguments (e.g. sample_rate) for ``add_audio``."""
        self._audio_configs = configs

    def _write_texts(self, texts):
        for key, value in texts.items():
            self._writer.add_text(key, value, self._global_epoch_step)
def trainG(args):
    """Train the generator alone with a multi-resolution STFT loss.

    Args:
        args: nested config dict. Sections read here: ``logging`` (save_path,
            load_path, n_test_samples, save_interval, log_interval), ``fft``,
            ``Generator`` (ngf, n_residual_layers, ratios, optional G_path),
            ``optimizer`` (lrG, betasG), ``data`` (data_path, seq_len,
            sampling_rate), ``dataloader`` (batch_size), ``train`` (epochs) and
            ``losses`` (lambda_sc, lambda_sm).

    Side effects:
        Writes ``args.yml``, reference and generated ``.wav`` samples,
        ``netG.pt`` / ``optG.pt`` / ``best_netG.pt`` and TensorBoard events
        under ``args['logging']['save_path']``.
    """
    root = Path(args['logging']['save_path'])
    load_root = (Path(args['logging']['load_path'])
                 if args['logging']['load_path'] else None)
    root.mkdir(parents=True, exist_ok=True)
    sampling_rate = args['data']['sampling_rate']

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    writer = SummaryWriter(str(root))

    #######################
    # Load PyTorch Models #
    #######################
    netG = Generator(args['fft']['n_mel_channels'],
                     args['Generator']['ngf'],
                     args['Generator']['n_residual_layers'],
                     ratios=args['Generator']['ratios']).cuda()
    g_path = args['Generator'].get('G_path')
    if g_path is not None:
        # G_path normally arrives from YAML as a plain str. Wrap it in Path so
        # the "/" join works (str / str raises TypeError).
        netG.load_state_dict(torch.load(Path(g_path) / "netG.pt"))
    fft = Audio2Mel(n_mel_channels=args['fft']['n_mel_channels'],
                    n_fft=args['fft']['n_fft'],
                    hop_length=args['fft']['hop_length'],
                    win_length=args['fft']['win_length'],
                    sampling_rate=sampling_rate,
                    mel_fmin=args['fft']['mel_fmin'],
                    mel_fmax=args['fft']['mel_fmax']).cuda()
    print(netG)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(),
                            lr=args['optimizer']['lrG'],
                            betas=args['optimizer']['betasG'])
    if load_root and load_root.exists():
        netG.load_state_dict(torch.load(load_root / "netG.pt"))
        optG.load_state_dict(torch.load(load_root / "optG.pt"))
        print('checkpoints loaded')

    #######################
    # Create data loaders #
    #######################
    data_path = Path(args['data']['data_path'])
    train_set = AudioDataset(data_path / "train_files_inv.txt",
                             segment_length=args['data']['seq_len'],
                             sampling_rate=sampling_rate,
                             augment=['amp', 'flip', 'neg'])
    test_set = AudioDataset(data_path / "test_files_inv.txt",
                            segment_length=sampling_rate * 4,
                            sampling_rate=sampling_rate,
                            augment=None)
    train_loader = DataLoader(train_set,
                              batch_size=args['dataloader']['batch_size'],
                              num_workers=4,
                              pin_memory=True,
                              shuffle=True)
    test_loader = DataLoader(test_set, batch_size=1)

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    # zip with range stops after exactly n_test_samples clips. The old
    # "break at i == n - 1" dumped the whole test set when n was 0.
    for i, x_t in zip(range(args['logging']['n_test_samples']), test_loader):
        x_t = x_t.cuda()
        s_t = fft(x_t).detach()

        test_voc.append(s_t.cuda())
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i), sampling_rate, audio)
        writer.add_audio("original/sample_%d.wav" % i, audio, 0,
                         sample_rate=sampling_rate)

    costs = []
    start = time.time()

    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000
    steps = 0
    mr_stft_loss = MultiResolutionSTFTLoss().cuda()
    for epoch in range(1, args['train']['epochs'] + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.cuda()
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.cuda())

            # Mel reconstruction error: monitoring only, no gradient.
            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            ###################
            # Train Generator #
            ###################
            sc, sm = mr_stft_loss(x_pred_t, x_t)
            loss_G = (args['losses']['lambda_sc'] * sc
                      + args['losses']['lambda_sm'] * sm)

            netG.zero_grad()
            loss_G.backward()
            optG.step()

            ######################
            # Update tensorboard #
            ######################
            costs.append([loss_G.item(), sc.item(), sm.item(), s_error])

            writer.add_scalar("loss/generator", costs[-1][0], steps)
            writer.add_scalar("loss/convergence", costs[-1][1], steps)
            writer.add_scalar("loss/logmag", costs[-1][2], steps)
            writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps)
            steps += 1

            if steps % args['logging']['save_interval'] == 0:
                st = time.time()
                with torch.no_grad():
                    for i, voc in enumerate(test_voc):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(root / ("generated_%d.wav" % i),
                                    sampling_rate, pred_audio)
                        # Use the same global step as the scalars; with the
                        # epoch index, several saves in one epoch collided.
                        writer.add_audio(
                            "generated/sample_%d.wav" % i,
                            pred_audio,
                            steps,
                            sample_rate=sampling_rate,
                        )

                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")

                # costs is non-empty here: it was appended to above.
                mean_mel_error = np.asarray(costs).mean(0)[-1]
                if mean_mel_error < best_mel_reconst:
                    best_mel_reconst = mean_mel_error
                    torch.save(netG.state_dict(), root / "best_netG.pt")

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args['logging']['log_interval'] == 0:
                print("Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".
                      format(
                          epoch,
                          iterno,
                          len(train_loader),
                          1000 * (time.time() - start) /
                          args['logging']['log_interval'],
                          np.asarray(costs).mean(0),
                      ))
                costs = []
                start = time.time()

    writer.close()
def train(rank, a, h):
    """Adversarial (MPD + MSD) vocoder training loop for a single process/GPU.

    Args:
        rank: process index, also used as the CUDA device ordinal.
        a: run arguments (checkpoint_path, training_epochs, stdout_interval,
            checkpoint_interval, summary_interval, validation_interval,
            fine_tuning, input_mels_dir, ...).
        h: hyper-parameters (num_gpus, dist_config, seed, learning_rate,
            adam_b1/b2, lr_decay, STFT/mel settings, batch_size, ...).

    Only rank 0 prints, checkpoints, validates and writes TensorBoard events.
    """
    if h.num_gpus > 1:
        init_process_group(backend=h.dist_config['dist_backend'],
                           init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus,
                           rank=rank)

    torch.cuda.manual_seed(h.seed)
    torch.cuda.set_device(rank)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = Generator(h).to(device)
    mpd = MultiPeriodDiscriminator().to(device)
    msd = MultiScaleDiscriminator().to(device)

    if rank == 0:
        print(generator)
        os.makedirs(a.checkpoint_path, exist_ok=True)
        print("checkpoints directory : ", a.checkpoint_path)

    # Non-zero ranks skip makedirs, so the directory may not exist for them.
    # Default to "no checkpoint" instead of leaving these names unbound.
    cp_g = cp_do = None
    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        mpd.load_state_dict(state_dict_do['mpd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if h.num_gpus > 1:
        generator = DistributedDataParallel(generator,
                                            device_ids=[rank]).to(device)
        mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate,
                                betas=[h.adam_b1, h.adam_b2])
    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(),
                                                mpd.parameters()),
                                h.learning_rate,
                                betas=[h.adam_b1, h.adam_b2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=h.lr_decay, last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(a)

    trainset = MelDataset(training_filelist, h.segment_size, h.n_fft,
                          h.num_mels, h.hop_size, h.win_size, h.sampling_rate,
                          h.fmin, h.fmax,
                          n_cache_reuse=0,
                          # With several GPUs the DistributedSampler shuffles.
                          shuffle=False if h.num_gpus > 1 else True,
                          fmax_loss=h.fmax_for_loss,
                          device=device,
                          fine_tuning=a.fine_tuning,
                          base_mels_path=a.input_mels_dir)

    train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None

    train_loader = DataLoader(trainset,
                              num_workers=h.num_workers,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        validset = MelDataset(validation_filelist, h.segment_size, h.n_fft,
                              h.num_mels, h.hop_size, h.win_size,
                              h.sampling_rate, h.fmin, h.fmax, False, False,
                              n_cache_reuse=0,
                              fmax_loss=h.fmax_for_loss,
                              device=device,
                              fine_tuning=a.fine_tuning,
                              base_mels_path=a.input_mels_dir)
        validation_loader = DataLoader(validset,
                                       num_workers=1,
                                       shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))

    generator.train()
    mpd.train()
    msd.train()
    for epoch in range(max(0, last_epoch), a.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if h.num_gpus > 1:
            train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            if rank == 0:
                start_b = time.time()
            x, y, _, y_mel = batch
            # torch.autograd.Variable is a no-op on modern PyTorch; move only.
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            y_mel = y_mel.to(device, non_blocking=True)
            y = y.unsqueeze(1)

            y_g_hat = generator(x)
            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft,
                                          h.num_mels, h.sampling_rate,
                                          h.hop_size, h.win_size, h.fmin,
                                          h.fmax_for_loss)

            optim_d.zero_grad()

            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
                y_df_hat_r, y_df_hat_g)

            # MSD
            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
                y_ds_hat_r, y_ds_hat_g)

            loss_disc_all = loss_disc_s + loss_disc_f

            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
            loss_gen_all = (loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f
                            + loss_mel)

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                log_stdout = steps % a.stdout_interval == 0
                log_summary = steps % a.summary_interval == 0
                # Compute the mel error whenever either logger needs it. It
                # used to be computed for stdout only, so the TensorBoard
                # scalar got a stale value (or a NameError).
                if log_stdout or log_summary:
                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()

                # STDOUT logging
                if log_stdout:
                    print(
                        'Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'
                        .format(steps, loss_gen_all, mel_error,
                                time.time() - start_b))

                # checkpointing
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_{:08d}".format(
                        a.checkpoint_path, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'generator': (generator.module if h.num_gpus > 1
                                          else generator).state_dict()
                        })
                    checkpoint_path = "{}/do_{:08d}".format(
                        a.checkpoint_path, steps)
                    save_checkpoint(
                        checkpoint_path, {
                            'mpd': (mpd.module if h.num_gpus > 1 else
                                    mpd).state_dict(),
                            'msd': (msd.module if h.num_gpus > 1 else
                                    msd).state_dict(),
                            'optim_g': optim_g.state_dict(),
                            'optim_d': optim_d.state_dict(),
                            'steps': steps,
                            'epoch': epoch
                        })

                # Tensorboard summary logging
                if log_summary:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all,
                                  steps)
                    sw.add_scalar("training/mel_spec_error", mel_error, steps)

                # Validation
                if steps % a.validation_interval == 0:  # and steps != 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    num_val_batches = 0
                    with torch.no_grad():
                        for j, batch in enumerate(validation_loader):
                            x, y, _, y_mel = batch
                            y_g_hat = generator(x.to(device))
                            y_mel = y_mel.to(device, non_blocking=True)
                            y_g_hat_mel = mel_spectrogram(
                                y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                h.sampling_rate, h.hop_size, h.win_size,
                                h.fmin, h.fmax_for_loss)
                            val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
                            num_val_batches += 1

                            # Log audio / spectrograms for the first 5 clips;
                            # ground truth only once, at step 0.
                            if j <= 4:
                                if steps == 0:
                                    sw.add_audio('gt/y_{}'.format(j), y[0],
                                                 steps, h.sampling_rate)
                                    sw.add_figure('gt/y_spec_{}'.format(j),
                                                  plot_spectrogram(x[0]),
                                                  steps)

                                sw.add_audio('generated/y_hat_{}'.format(j),
                                             y_g_hat[0], steps,
                                             h.sampling_rate)
                                y_hat_spec = mel_spectrogram(
                                    y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                    h.sampling_rate, h.hop_size, h.win_size,
                                    h.fmin, h.fmax)
                                sw.add_figure(
                                    'generated/y_hat_spec_{}'.format(j),
                                    plot_spectrogram(
                                        y_hat_spec.squeeze(0).cpu().numpy()),
                                    steps)

                    # max() guards an empty validation set (j would be unbound).
                    val_err = val_err_tot / max(num_val_batches, 1)
                    sw.add_scalar("validation/mel_spec_error", val_err, steps)

                    generator.train()

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
def generate(output_directory, tensorboard_directory, num_samples, ckpt_path,
             ckpt_iter):
    """
    Generate audio based on ground truth mel spectrogram

    Parameters:
    output_directory (str):         save generated speeches to this path
    tensorboard_directory (str):    save tensorboard events to this path
    num_samples (int):              number of samples to generate, default is 4
    ckpt_path (str):                checkpoint path
    ckpt_iter (int or 'max'):       the pretrained checkpoint to be loaded;
                                    automatically selects the maximum iteration if 'max' is selected

    Raises:
    Exception: if the checkpoint cannot be loaded (original error chained as cause).
    """
    # generate experiment (local) path
    local_path = "ch{}_T{}_betaT{}".format(wavenet_config["res_channels"],
                                           diffusion_config["T"],
                                           diffusion_config["beta_T"])

    # Get shared output_directory ready
    output_directory = os.path.join('exp', local_path, output_directory)
    if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
        os.chmod(output_directory, 0o775)
    print("output directory", output_directory, flush=True)

    # map diffusion hyperparameters to gpu (everything except the integer T)
    for key in diffusion_hyperparams:
        # Compare by value: `is not` on a str literal tests identity.
        if key != "T":
            diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()

    # predefine model
    net = WaveNet(**wavenet_config).cuda()
    print_size(net)

    # load checkpoint
    ckpt_path = os.path.join('exp', local_path, ckpt_path)
    if ckpt_iter == 'max':
        ckpt_iter = find_max_epoch(ckpt_path)
    model_path = os.path.join(ckpt_path, '{}.pkl'.format(ckpt_iter))
    try:
        checkpoint = torch.load(model_path, map_location='cpu')
        net.load_state_dict(checkpoint['model_state_dict'])
    except Exception as err:
        # Keep the caller-visible exception type, but preserve the cause and
        # stop swallowing KeyboardInterrupt/SystemExit like the bare except did.
        raise Exception('No valid model found') from err
    print('Successfully loaded model at iteration {}'.format(ckpt_iter))

    # predefine audio shape
    audio_length = trainset_config["segment_length"]  # 16000
    print('begin generating audio of length %s' % audio_length)

    # inference
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()

    generated_audio = sampling(net, (num_samples, 1, audio_length),
                               diffusion_hyperparams)

    end.record()
    torch.cuda.synchronize()
    print('generated {} utterances of random_digit at iteration {} in {} seconds'.format(
        num_samples, ckpt_iter, int(start.elapsed_time(end) / 1000)))

    # Open one TensorBoard writer for all samples, not one per sample.
    tb = SummaryWriter(os.path.join('exp', local_path, tensorboard_directory))
    try:
        for i in range(num_samples):
            # save audio to .wav
            outfile = '{}_{}_{}k_{}.wav'.format(wavenet_config["res_channels"],
                                                diffusion_config["T"],
                                                ckpt_iter // 1000, i)
            wavwrite(os.path.join(output_directory, outfile),
                     trainset_config["sampling_rate"],
                     generated_audio[i].squeeze().cpu().numpy())

            # save audio to tensorboard
            tb.add_audio(tag=outfile, snd_tensor=generated_audio[i],
                         sample_rate=trainset_config["sampling_rate"])
    finally:
        tb.close()

    print('saved generated samples at iteration %s' % ckpt_iter)
class BaseTrainer:
    """Shared DDP/AMP training harness for speech-enhancement models.

    Subclasses implement ``_train_epoch(epoch)`` and ``_validation_epoch(epoch)``
    (the latter returns a metric score). This base class handles checkpointing,
    resuming, model preloading, TensorBoard visualisation and the epoch loop.
    """

    def __init__(self, dist, rank, config, resume, only_validation, model,
                 loss_function, optimizer):
        """
        Args:
            dist: the ``torch.distributed`` module (used for ``barrier``).
            rank: process rank, also used as the device for the model.
            config: parsed TOML config with ``meta``, ``acoustics`` and
                ``trainer`` sections.
            resume: resume from ``checkpoints/latest_model.tar`` when True.
            only_validation: skip training and only run validation (debug).
            model: the network; it is wrapped in ``DistributedDataParallel``.
            loss_function: loss callable used by subclasses.
            optimizer: optimizer over ``model``'s parameters.
        """
        self.color_tool = colorful
        self.color_tool.use_style("solarized")

        model = DistributedDataParallel(model.to(rank), device_ids=[rank])
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function

        # DistributedDataParallel (DDP)
        self.rank = rank
        self.dist = dist

        # Automatic mixed precision (AMP)
        self.use_amp = config["meta"]["use_amp"]
        self.scaler = GradScaler(enabled=self.use_amp)

        # Acoustics
        self.acoustic_config = config["acoustics"]

        # Supported STFT
        n_fft = self.acoustic_config["n_fft"]
        hop_length = self.acoustic_config["hop_length"]
        win_length = self.acoustic_config["win_length"]

        self.torch_stft = partial(stft,
                                  n_fft=n_fft,
                                  hop_length=hop_length,
                                  win_length=win_length)
        self.torch_istft = partial(istft,
                                   n_fft=n_fft,
                                   hop_length=hop_length,
                                   win_length=win_length)
        self.librosa_stft = partial(librosa.stft,
                                    n_fft=n_fft,
                                    hop_length=hop_length,
                                    win_length=win_length)
        self.librosa_istft = partial(librosa.istft,
                                     hop_length=hop_length,
                                     win_length=win_length)

        # Trainer.train in the config
        self.train_config = config["trainer"]["train"]
        self.epochs = self.train_config["epochs"]
        self.save_checkpoint_interval = self.train_config[
            "save_checkpoint_interval"]
        self.clip_grad_norm_value = self.train_config["clip_grad_norm_value"]
        assert self.save_checkpoint_interval >= 1, "Check the 'save_checkpoint_interval' parameter in the config. It should be large than one."

        # Trainer.validation in the config
        self.validation_config = config["trainer"]["validation"]
        self.validation_interval = self.validation_config[
            "validation_interval"]
        self.save_max_metric_score = self.validation_config[
            "save_max_metric_score"]
        assert self.validation_interval >= 1, "Check the 'validation_interval' parameter in the config. It should be large than one."

        # Trainer.visualization in the config
        self.visualization_config = config["trainer"]["visualization"]

        # In the 'train.py' file, if the 'resume' item is 'True', we will update the following args:
        self.start_epoch = 1
        self.best_score = -np.inf if self.save_max_metric_score else np.inf
        self.save_dir = Path(config["meta"]["save_dir"]).expanduser().absolute(
        ) / config["meta"]["experiment_name"]
        self.checkpoints_dir = self.save_dir / "checkpoints"
        self.logs_dir = self.save_dir / "logs"

        if resume:
            self._resume_checkpoint()

        # Debug validation, which skips training
        self.only_validation = only_validation

        # Read the path from the same "meta" section the condition checks; the
        # top-level config["preloaded_model_path"] lookup raised KeyError.
        preloaded_model_path = config["meta"]["preloaded_model_path"]
        if preloaded_model_path:
            self._preload_model(Path(preloaded_model_path))

        if self.rank == 0:
            prepare_empty_dir([self.checkpoints_dir, self.logs_dir],
                              resume=resume)

            self.writer = SummaryWriter(self.logs_dir.as_posix(),
                                        max_queue=5,
                                        flush_secs=30)
            self.writer.add_text(
                tag="Configuration",
                text_string=f"<pre>  \n{toml.dumps(config)}  \n</pre>",
                global_step=1)

            print(self.color_tool.cyan("The configurations are as follows: "))
            print(self.color_tool.cyan("=" * 40))
            print(self.color_tool.cyan(toml.dumps(config)[:-1]))  # except "\n"
            print(self.color_tool.cyan("=" * 40))

            with open(
                (self.save_dir /
                 f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(),
                    "w") as handle:
                toml.dump(config, handle)

            self._print_networks([self.model])

    def _preload_model(self, model_path):
        """
        Preload model parameters (in "*.tar" format) at the start of experiment.

        Args:
            model_path (Path): The file path of the *.tar file
        """
        model_path = model_path.expanduser().absolute()
        assert model_path.exists(
        ), f"The file {model_path.as_posix()} is not exist. please check path."

        model_checkpoint = torch.load(model_path.as_posix(),
                                      map_location="cpu")
        # self.model is normally DDP-wrapped, so its state_dict keys carry a
        # "module." prefix. _save_checkpoint stores the unwrapped module's
        # keys, so loading into the wrapper with strict=False would silently
        # match nothing. Load into the underlying module, as
        # _resume_checkpoint does.
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            target = self.model.module
        else:
            target = self.model
        target.load_state_dict(model_checkpoint["model"], strict=False)
        self.model.to(self.rank)

        if self.rank == 0:
            print(
                f"Model preloaded successfully from {model_path.as_posix()}.")

    def _resume_checkpoint(self):
        """
        Resume the experiment from the latest checkpoint.
        """
        latest_model_path = self.checkpoints_dir.expanduser().absolute(
        ) / "latest_model.tar"
        assert latest_model_path.exists(
        ), f"{latest_model_path} does not exist, can not load latest checkpoint."

        # Make sure all processes (GPUs) do not start loading before the saving is finished.
        # see https://stackoverflow.com/questions/59760328/how-does-torch-distributed-barrier-work
        self.dist.barrier()

        # Load it on the CPU and later use .to(device) on the model
        # Maybe slightly slow than use map_location="cuda:<...>"
        # https://stackoverflow.com/questions/61642619/pytorch-distributed-data-parallel-confusion
        checkpoint = torch.load(latest_model_path.as_posix(),
                                map_location="cpu")

        self.start_epoch = checkpoint["epoch"] + 1
        self.best_score = checkpoint["best_score"]
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.scaler.load_state_dict(checkpoint["scaler"])

        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            self.model.module.load_state_dict(checkpoint["model"])
        else:
            self.model.load_state_dict(checkpoint["model"])

        if self.rank == 0:
            print(
                f"Model checkpoint loaded. Training will begin at {self.start_epoch} epoch."
            )

    def _save_checkpoint(self, epoch, is_best_epoch=False):
        """
        Save checkpoint to "<save_dir>/<config name>/checkpoints" directory, which consists of:
            - epoch
            - best metric score in historical epochs
            - optimizer parameters
            - AMP grad-scaler state
            - model parameters

        Args:
            is_best_epoch (bool): In the current epoch, if the model get a best metric score (is_best_epoch=True),
                                the checkpoint of model will be saved as "<save_dir>/checkpoints/best_model.tar".
        """
        print(f"\t Saving {epoch} epoch model checkpoint...")

        state_dict = {
            "epoch": epoch,
            "best_score": self.best_score,
            "optimizer": self.optimizer.state_dict(),
            "scaler": self.scaler.state_dict()
        }

        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            state_dict["model"] = self.model.module.state_dict()
        else:
            state_dict["model"] = self.model.state_dict()

        # Saved in "latest_model.tar"
        # Contains all checkpoint information, including the optimizer parameters, the model parameters, etc.
        # New checkpoint will overwrite the older one.
        torch.save(state_dict,
                   (self.checkpoints_dir / "latest_model.tar").as_posix())

        # "model_{epoch_number}.pth"
        # Contains only model.
        torch.save(state_dict["model"],
                   (self.checkpoints_dir /
                    f"model_{str(epoch).zfill(4)}.pth").as_posix())

        # If the model get a best metric score (means "is_best_epoch=True") in the current epoch,
        # the model checkpoint will be saved as "best_model.tar"
        # The newer best-scored checkpoint will overwrite the older one.
        if is_best_epoch:
            print(
                self.color_tool.red(
                    f"\t Found a best score in the {epoch} epoch, saving..."))
            torch.save(state_dict,
                       (self.checkpoints_dir / "best_model.tar").as_posix())

    def _is_best_epoch(self, score, save_max_metric_score=True):
        """
        Check if the current model got the best metric score (ties count as best)
        and, if so, record it in ``self.best_score``.
        """
        if save_max_metric_score and score >= self.best_score:
            self.best_score = score
            return True
        elif not save_max_metric_score and score <= self.best_score:
            self.best_score = score
            return True
        else:
            return False

    @staticmethod
    def _print_networks(models: list):
        """Print the parameter count (in millions) of each model and in total."""
        print(
            f"This project contains {len(models)} models, the number of the parameters is: "
        )

        params_of_all_networks = 0
        for idx, model in enumerate(models, start=1):
            params_of_network = 0
            for param in model.parameters():
                params_of_network += param.numel()

            print(f"\tNetwork {idx}: {params_of_network / 1e6} million.")
            params_of_all_networks += params_of_network

        print(
            f"The amount of parameters in the project is {params_of_all_networks / 1e6} million."
        )

    def _set_models_to_train_mode(self):
        self.model.train()

    def _set_models_to_eval_mode(self):
        self.model.eval()

    def spec_audio_visualization(self,
                                 noisy,
                                 enhanced,
                                 clean,
                                 name,
                                 epoch,
                                 mark=""):
        """Log the three waveforms (16 kHz) and their magnitude spectrograms."""
        self.writer.add_audio(f"{mark}_Speech/{name}_Noisy",
                              noisy,
                              epoch,
                              sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Enhanced",
                              enhanced,
                              epoch,
                              sample_rate=16000)
        self.writer.add_audio(f"{mark}_Speech/{name}_Clean",
                              clean,
                              epoch,
                              sample_rate=16000)

        # Visualize the spectrogram of noisy speech, clean speech, and enhanced speech
        noisy_mag, _ = librosa.magphase(
            self.librosa_stft(noisy, n_fft=320, hop_length=160,
                              win_length=320))
        enhanced_mag, _ = librosa.magphase(
            self.librosa_stft(enhanced,
                              n_fft=320,
                              hop_length=160,
                              win_length=320))
        clean_mag, _ = librosa.magphase(
            self.librosa_stft(clean, n_fft=320, hop_length=160,
                              win_length=320))
        fig, axes = plt.subplots(3, 1, figsize=(6, 6))
        for k, mag in enumerate([noisy_mag, enhanced_mag, clean_mag]):
            axes[k].set_title(f"mean: {np.mean(mag):.3f}, "
                              f"std: {np.std(mag):.3f}, "
                              f"max: {np.max(mag):.3f}, "
                              f"min: {np.min(mag):.3f}")
            librosa.display.specshow(librosa.amplitude_to_db(mag),
                                     cmap="magma",
                                     y_axis="linear",
                                     ax=axes[k],
                                     sr=16000)
        plt.tight_layout()
        self.writer.add_figure(f"{mark}_Spectrogram/{name}", fig, epoch)

    def metrics_visualization(self,
                              noisy_list,
                              clean_list,
                              enhanced_list,
                              metrics_list,
                              epoch,
                              num_workers=10,
                              mark=""):
        """
        Get metrics on validation dataset by paralleling.

        Notes:
            1. You can register other metrics, but STOI and WB_PESQ metrics must be existence. These two metrics are
             used for checking if the current epoch is a "best epoch."
            2. If you want to use a new metric, you must register it in "util.metrics" file.

        Returns:
            The mean of the enhanced-speech STOI and the range-transformed WB_PESQ.
        """
        assert "STOI" in metrics_list and "WB_PESQ" in metrics_list, "'STOI' and 'WB_PESQ' must be existence."

        # Check if the metric is registered in "util.metrics" file.
        for i in metrics_list:
            assert i in metrics.REGISTERED_METRICS.keys(
            ), f"{i} is not registered, please check 'util.metrics' file."

        stoi_mean = 0.0
        wb_pesq_mean = 0.0
        for metric_name in metrics_list:
            score_on_noisy = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est)
                for ref, est in zip(clean_list, noisy_list))
            score_on_enhanced = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est)
                for ref, est in zip(clean_list, enhanced_list))

            # Add the mean value of the metric to tensorboard
            mean_score_on_noisy = np.mean(score_on_noisy)
            mean_score_on_enhanced = np.mean(score_on_enhanced)
            self.writer.add_scalars(f"{mark}_Validation/{metric_name}", {
                "Noisy": mean_score_on_noisy,
                "Enhanced": mean_score_on_enhanced
            }, epoch)

            if metric_name == "STOI":
                stoi_mean = mean_score_on_enhanced

            if metric_name == "WB_PESQ":
                wb_pesq_mean = transform_pesq_range(mean_score_on_enhanced)

        return (stoi_mean + wb_pesq_mean) / 2

    def train(self):
        """Run the epoch loop: train, periodically checkpoint and validate."""
        for epoch in range(self.start_epoch, self.epochs + 1):
            if self.rank == 0:
                print(
                    self.color_tool.yellow(
                        f"{'=' * 15} {epoch} epoch {'=' * 15}"))
                print("[0 seconds] Begin training...")

            # [debug validation] Only run validation (only use the first GPU (process))
            # inference + calculating metrics + saving checkpoints
            if self.only_validation and self.rank == 0:
                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(
                        metric_score,
                        save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

                # Skip the following regular training, saving checkpoints, and validation
                continue

            # Regular training
            timer = ExecutionTime()
            self._set_models_to_train_mode()
            self._train_epoch(epoch)

            #  Regular save checkpoints
            if self.rank == 0 and self.save_checkpoint_interval != 0 and (
                    epoch % self.save_checkpoint_interval == 0):
                self._save_checkpoint(epoch)

            # Regular validation
            if self.rank == 0 and (epoch % self.validation_interval == 0):
                print(
                    f"[{timer.duration()} seconds] Training has finished, validation is in progress..."
                )

                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(
                        metric_score,
                        save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

                print(f"[{timer.duration()} seconds] This epoch is finished.")

    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _validation_epoch(self, epoch):
        raise NotImplementedError
def main():
    """Train the noise-conditioned generator with a multi-resolution STFT loss.

    Arguments come from ``parse_args()``. ``args.yml``, reference and generated
    ``.wav`` samples, ``netG.pt`` / ``optG.pt`` / ``best_netG.pt`` and
    TensorBoard events are written to ``args.save_path``.
    """
    # Audio rate used for the mel front-end, datasets and saved samples.
    sampling_rate = 22050

    args = parse_args()

    root = Path(args.save_path)
    load_root = Path(args.load_path) if args.load_path else None
    print(load_root)
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    writer = SummaryWriter(str(root))

    #######################
    # Load PyTorch Models #
    #######################
    netG = Generator(args.n_mel_channels).cuda()
    fft = Audio2Mel(n_mel_channels=args.n_mel_channels,
                    mel_fmin=40,
                    mel_fmax=None,
                    sampling_rate=sampling_rate).cuda()

    print(netG)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))

    if load_root and load_root.exists():
        netG.load_state_dict(torch.load(load_root / "netG.pt"))
        optG.load_state_dict(torch.load(load_root / "optG.pt"))
        print('checkpoints loaded')

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(Path(args.data_path) / "train_files.txt",
                             args.seq_len,
                             sampling_rate=sampling_rate)
    test_set = AudioDataset(
        Path(args.data_path) / "test_files.txt",
        # ~4 s, rounded down to a multiple of 32 hops of 256 samples.
        ((sampling_rate * 4 // 256) // 32) * 32 * 256,
        sampling_rate=sampling_rate,
        augment=False,
    )

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=1)

    mr_stft_loss = MultiResolutionSTFTLoss().cuda()

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    # zip with range stops after exactly n_test_samples clips. The old
    # "break at i == n - 1" dumped the whole test set when n was 0.
    for i, x_t in zip(range(args.n_test_samples), test_loader):
        x_t = x_t.cuda()
        s_t = fft(x_t).detach()

        test_voc.append(s_t.cuda())
        test_audio.append(x_t.cpu())

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i), sampling_rate, audio)
        writer.add_audio("original/sample_%d.wav" % i,
                         audio,
                         0,
                         sample_rate=sampling_rate)

    costs = []
    start = time.time()

    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000
    steps = 0
    for epoch in range(1, args.epochs + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.cuda()
            s_t = fft(x_t).detach()
            # NOTE(review): training noise has time length 1, while sampling
            # below uses length 10. Confirm the generator broadcasts or
            # ignores this axis.
            n = torch.randn(x_t.shape[0], 128, 1).cuda()
            x_pred_t = netG(s_t.cuda(), n)

            ###################
            # Train Generator #
            ###################
            # Mel reconstruction error: monitoring only, no gradient.
            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            sc_loss, mag_loss = mr_stft_loss(x_pred_t, x_t)

            loss_G = sc_loss + mag_loss

            netG.zero_grad()
            loss_G.backward()
            optG.step()

            ######################
            # Update tensorboard #
            ######################
            costs.append(
                [loss_G.item(), sc_loss.item(), mag_loss.item(), s_error])

            writer.add_scalar("loss/generator", costs[-1][0], steps)
            writer.add_scalar("loss/spectral_convergence", costs[-1][1], steps)
            writer.add_scalar("loss/log_spectrum", costs[-1][2], steps)
            writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps)
            steps += 1

            if steps % args.save_interval == 0:
                st = time.time()
                with torch.no_grad():
                    for i, voc in enumerate(test_voc):
                        n = torch.randn(1, 128, 10).cuda()
                        pred_audio = netG(voc, n)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(root / ("generated_%d.wav" % i),
                                    sampling_rate, pred_audio)
                        # Use the same global step as the scalars; with the
                        # epoch index, several saves in one epoch collided.
                        writer.add_audio(
                            "generated/sample_%d.wav" % i,
                            pred_audio,
                            steps,
                            sample_rate=sampling_rate,
                        )

                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")

                # costs is non-empty here: it was appended to above.
                mean_mel_error = np.asarray(costs).mean(0)[-1]
                if mean_mel_error < best_mel_reconst:
                    best_mel_reconst = mean_mel_error
                    torch.save(netG.state_dict(), root / "best_netG.pt")

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args.log_interval == 0:
                print(
                    "Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".format(
                        epoch,
                        iterno,
                        len(train_loader),
                        1000 * (time.time() - start) / args.log_interval,
                        np.asarray(costs).mean(0),
                    )
                )
                costs = []
                start = time.time()

    writer.close()