class Trainer:
    def __init__(self, path_state_dict=''):
        self.model = DeGLI(**hp.model)
        self.module = self.model
        self.criterion = nn.L1Loss(reduction='none')
        self.optimizer = Adam(
            self.model.parameters(),
            lr=hp.learning_rate,
            weight_decay=hp.weight_decay,
        )

        self.__init_device(hp.device, hp.out_device)

        self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, **hp.scheduler)
        self.max_epochs = hp.n_epochs

        self.writer: Optional[CustomWriter] = None
        self.valid_eval_sample: Dict[str, Any] = dict()

        # if hp.model['final_avg']:
        #     len_weight = hp.repeat_train
        # else:
        #     len_weight = hp.model['depth'] * hp.repeat_train
        len_weight = hp.repeat_train
        self.loss_weight = torch.tensor(
            [1. / i for i in range(len_weight, 0, -1)],
            device=self.out_device,
        )
        self.loss_weight /= self.loss_weight.sum()

        # Load State Dict
        if path_state_dict:
            st_model, st_optim, st_sched = torch.load(path_state_dict, map_location=self.in_device)
            try:
                self.module.load_state_dict(st_model)
                self.optimizer.load_state_dict(st_optim)
                self.scheduler.load_state_dict(st_sched)
            except Exception:
                raise Exception('The model is different from the state dict.')

        path_summary = hp.logdir / 'summary.txt'
        if not path_summary.exists():
            # print_to_file(
            #     path_summary,
            #     summary,
            #     (self.model, hp.dummy_input_size),
            #     dict(device=self.str_device[:4])
            # )
            with path_summary.open('w') as f:
                f.write('\n')
            with (hp.logdir / 'hparams.txt').open('w') as f:
                f.write(repr(hp))

    def __init_device(self, device, out_device):
        """
        :type device: Union[int, str, Sequence]
        :type out_device: Union[int, str, Sequence]
        :return:
        """
        if device == 'cpu':
            self.in_device = torch.device('cpu')
            self.out_device = torch.device('cpu')
            self.str_device = 'cpu'
            return

        # normalize `device` to List[int]
        if type(device) == int:
            device = [device]
        elif type(device) == str:
            device = [int(device.replace('cuda:', ''))]
        else:  # sequence of devices
            if type(device[0]) != int:
                device = [int(d.replace('cuda:', '')) for d in device]

        self.in_device = torch.device(f'cuda:{device[0]}')

        if len(device) > 1:
            if type(out_device) == int:
                self.out_device = torch.device(f'cuda:{out_device}')
            else:
                self.out_device = torch.device(out_device)
            self.str_device = ', '.join([f'cuda:{d}' for d in device])
            self.model = nn.DataParallel(self.model, device_ids=device,
                                         output_device=self.out_device)
        else:
            self.out_device = self.in_device
            self.str_device = str(self.in_device)

        self.model.cuda(self.in_device)
        self.criterion.cuda(self.out_device)
        torch.cuda.set_device(self.in_device)

    def preprocess(self, data: Dict[str, Tensor]) -> Tuple[Tensor, Tensor, int, Tensor]:
        # B, F, T, C
        x = data['x']
        mag = data['y_mag']
        max_length = max(data['length'])
        y = data['y']

        x = x.to(self.in_device, non_blocking=True)
        mag = mag.to(self.in_device, non_blocking=True)
        y = y.to(self.out_device, non_blocking=True)

        return x, mag, max_length, y

    @torch.no_grad()
    def postprocess(self, output: Tensor, residual: Tensor, Ts: ndarray,
                    idx: int, dataset: ComplexSpecDataset) -> Dict[str, ndarray]:
        dict_one = dict(out=output, res=residual)
        for key in dict_one:
            one = dict_one[key][idx, :, :, :Ts[idx]]
            one = one.permute(1, 2, 0).contiguous()  # F, T, 2
            one = one.cpu().numpy().view(dtype=np.complex64)  # F, T, 1
            dict_one[key] = one
        return dict_one

    def calc_loss(self, out_blocks: Tensor, y: Tensor, T_ys: Sequence[int]) -> Tensor:
        """
        out_blocks: B, depth, C, F, T
        y: B, C, F, T
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            loss_no_red = self.criterion(out_blocks, y.unsqueeze(1))
        loss_blocks = torch.zeros(out_blocks.shape[1], device=y.device)
        for T, loss_batch in zip(T_ys, loss_no_red):
            loss_blocks += torch.mean(loss_batch[..., :T], dim=(1, 2, 3))

        if len(loss_blocks) == 1:
            loss = loss_blocks.squeeze()
        else:
            loss = loss_blocks @ self.loss_weight
        return loss

    @torch.no_grad()
    def should_stop(self, loss_valid, epoch):
        if epoch == self.max_epochs - 1:
            return True
        self.scheduler.step(loss_valid)
        # if self.scheduler.t_epoch == 0:  # if it is restarted now
        #     # if self.loss_last_restart < loss_valid:
        #     #     return True
        #     if self.loss_last_restart * hp.threshold_stop < loss_valid:
        #         self.max_epochs = epoch + self.scheduler.restart_period + 1
        #     self.loss_last_restart = loss_valid

    def train(self, loader_train: DataLoader, loader_valid: DataLoader,
              logdir: Path, first_epoch=0):
        self.writer = CustomWriter(str(logdir), group='train', purge_step=first_epoch)

        # Start Training
        for epoch in range(first_epoch, hp.n_epochs):
            self.writer.add_scalar('loss/lr', self.optimizer.param_groups[0]['lr'], epoch)
            print()
            pbar = tqdm(loader_train, desc=f'epoch {epoch:3d}', postfix='[]', dynamic_ncols=True)
            avg_loss = AverageMeter(float)
            avg_grad_norm = AverageMeter(float)

            for i_iter, data in enumerate(pbar):
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                T_ys = data['T_ys']

                # forward
                output_loss, _, _ = self.model(x, mag, max_length, repeat=hp.repeat_train)  # B, C, F, T
                loss = self.calc_loss(output_loss, y, T_ys)

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), hp.thr_clip_grad)
                self.optimizer.step()

                # print
                avg_loss.update(loss.item(), len(T_ys))
                pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')
                avg_grad_norm.update(grad_norm)

            self.writer.add_scalar('loss/train', avg_loss.get_average(), epoch)
            self.writer.add_scalar('loss/grad', avg_grad_norm.get_average(), epoch)

            # Validation
            # loss_valid = self.validate(loader_valid, logdir, epoch)
            loss_valid = self.validate(loader_valid, logdir, epoch, repeat=hp.repeat_train)

            # save loss & model
            if epoch % hp.period_save_state == hp.period_save_state - 1:
                torch.save((
                    self.module.state_dict(),
                    self.optimizer.state_dict(),
                    self.scheduler.state_dict(),
                ), logdir / f'{epoch}.pt')

            # Early stopping
            if self.should_stop(loss_valid, epoch):
                break

        self.writer.close()

    @torch.no_grad()
    def validate(self, loader: DataLoader, logdir: Path, epoch: int, repeat=1):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch:
        """
        suffix = f'_{repeat}' if repeat > 1 else ''

        self.model.eval()

        avg_loss = AverageMeter(float)
        pbar = tqdm(loader, desc='validate ', postfix='[0]', dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output_loss, output, residual = self.model(x, mag, max_length, repeat=repeat)

            # loss
            loss = self.calc_loss(output_loss, y, T_ys)
            avg_loss.update(loss.item(), len(T_ys))

            # print
            pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

            # write summary
            if i_iter == 0:
                # F, T, C
                if not self.valid_eval_sample:
                    self.valid_eval_sample = ComplexSpecDataset.decollate_padded(data, 0)

                out_one = self.postprocess(output, residual, T_ys, 0, loader.dataset)

                # ComplexSpecDataset.save_dirspec(
                #     logdir / hp.form_result.format(epoch),
                #     **self.valid_eval_sample, **out_one
                # )

                if not self.writer.reused_sample:
                    one_sample = self.valid_eval_sample
                else:
                    one_sample = dict()

                self.writer.write_one(epoch, **one_sample, **out_one, suffix=suffix)

        self.writer.add_scalar(f'loss/valid{suffix}', avg_loss.get_average(), epoch)

        self.model.train()

        return avg_loss.get_average()

    @torch.no_grad()
    def test(self, loader: DataLoader, logdir: Path):
        def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
            module_name = str(module).split('(')[0]
            dict_to_save = dict()
            # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
            dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

            i_module = module_counts[module_name]
            for i, o in enumerate(dict_to_save['out']):
                save_forward.writer.add_figure(
                    f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                    draw_spectrogram(o, to_db=False),
                    i,
                )
            scio.savemat(
                str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
                dict_to_save,
            )
            module_counts[module_name] += 1

        group = logdir.name.split('_')[0]
        self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()

        module_counts = None
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            save_forward.writer = self.writer
            for sub in self.module.children():
                if isinstance(sub, nn.ModuleList):
                    for m in sub:
                        m.register_forward_hook(save_forward)
                elif isinstance(sub, nn.ModuleDict):
                    for m in sub.values():
                        m.register_forward_hook(save_forward)
                else:
                    sub.register_forward_hook(save_forward)

        pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        cnt_sample = 0
        for i_iter, data in enumerate(pbar):
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            if module_counts is not None:
                module_counts = defaultdict(int)

            if 0 < hp.n_save_block_outs == i_iter:
                break
            _, output, residual = self.model(x, mag, max_length, repeat=hp.repeat_test)

            # write summary
            for i_b in range(len(T_ys)):
                i_sample = cnt_sample + i_b
                one_sample = ComplexSpecDataset.decollate_padded(data, i_b)  # F, T, C
                out_one = self.postprocess(output, residual, T_ys, i_b, loader.dataset)

                ComplexSpecDataset.save_dirspec(logdir / hp.form_result.format(i_sample),
                                                **one_sample, **out_one)

                measure = self.writer.write_one(i_sample, **out_one, **one_sample,
                                                suffix=f'_{hp.repeat_test}')
                if avg_measure is None:
                    avg_measure = AverageMeter(init_value=measure)
                else:
                    avg_measure.update(measure)
                # print
                # str_measure = arr2str(measure).replace('\n', '; ')
                # pbar.write(str_measure)
            cnt_sample += len(T_ys)

        self.model.train()

        avg_measure = avg_measure.get_average()
        self.writer.add_text(f'{group}/Average Measure/Proposed', str(avg_measure[0]))
        self.writer.add_text(f'{group}/Average Measure/Reverberant', str(avg_measure[1]))
        self.writer.close()  # Explicitly close

        print()
        str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
        print(f'Average: {str_avg_measure}')
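# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): this Trainer
# is meant to be driven by an entry script that sets up the hyperparameters in
# `hp` and builds the train/valid/test DataLoaders.  The names `loader_train`,
# `loader_valid`, and `loader_test` below are hypothetical placeholders.
#
#     trainer = Trainer(path_state_dict='')           # or a saved '<epoch>.pt'
#     trainer.train(loader_train, loader_valid, hp.logdir, first_epoch=0)
#     trainer.test(loader_test, hp.logdir / 'test')
# ---------------------------------------------------------------------------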
class Trainer:
    def __init__(self, path_state_dict=''):
        self.model_name = hp.model_name
        module = eval(hp.model_name)
        self.model = module(**getattr(hp, hp.model_name))
        self.criterion = nn.MSELoss(reduction='none')
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=hp.learning_rate,
            weight_decay=hp.weight_decay,
        )

        self.__init_device(hp.device, hp.out_device)

        self.scheduler: Optional[CosineLRWithRestarts] = None
        self.max_epochs = hp.n_epochs
        self.loss_last_restart = float('inf')

        self.writer: Optional[CustomWriter] = None

        # a sample in validation set for evaluation
        self.valid_eval_sample: Dict[str, Any] = dict()

        # Load State Dict
        if path_state_dict:
            st_model, st_optim = torch.load(path_state_dict, map_location=self.in_device)
            try:
                if hasattr(self.model, 'module'):
                    self.model.module.load_state_dict(st_model)
                else:
                    self.model.load_state_dict(st_model)
                self.optimizer.load_state_dict(st_optim)
            except Exception:
                raise Exception('The model is different from the state dict.')

        path_summary = hp.logdir / 'summary.txt'
        if not path_summary.exists():
            print_to_file(path_summary,
                          summary,
                          (self.model, hp.dummy_input_size),
                          dict(device=self.str_device[:4]))
            with (hp.logdir / 'hparams.txt').open('w') as f:
                f.write(repr(hp))

    def __init_device(self, device, out_device):
        """
        :type device: Union[int, str, Sequence]
        :type out_device: Union[int, str, Sequence]
        :return:
        """
        if device == 'cpu':
            self.in_device = torch.device('cpu')
            self.out_device = torch.device('cpu')
            self.str_device = 'cpu'
            return

        # normalize `device` to List[int]
        if type(device) == int:
            device = [device]
        elif type(device) == str:
            device = [int(device.replace('cuda:', ''))]
        else:  # sequence of devices
            if type(device[0]) != int:
                device = [int(d.replace('cuda:', '')) for d in device]

        self.in_device = torch.device(f'cuda:{device[0]}')

        if len(device) > 1:
            if type(out_device) == int:
                self.out_device = torch.device(f'cuda:{out_device}')
            else:
                self.out_device = torch.device(out_device)
            self.str_device = ', '.join([f'cuda:{d}' for d in device])
            self.model = nn.DataParallel(self.model, device_ids=device,
                                         output_device=self.out_device)
        else:
            self.out_device = self.in_device
            self.str_device = str(self.in_device)

        self.model.cuda(self.in_device)
        self.criterion.cuda(self.out_device)
        torch.cuda.set_device(self.in_device)

    def preprocess(self, data: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
        # B, F, T, C
        x = data['normalized_x']
        y = data['normalized_y']

        x = x.to(self.in_device, non_blocking=True)
        y = y.to(self.out_device, non_blocking=True)

        return x, y

    @torch.no_grad()
    def postprocess(self, output: Tensor, Ts: ndarray, idx: int,
                    dataset: DirSpecDataset) -> Dict[str, ndarray]:
        one = output[idx, :, :, :Ts[idx]]
        if self.model_name.startswith('UNet'):
            one = one.permute(1, 2, 0)  # F, T, C
            one = dataset.denormalize_(y=one)
        one = one.cpu().numpy()

        return dict(out=one)

    def calc_loss(self, output: Tensor, y: Tensor, T_ys: Sequence[int]) -> Tensor:
        loss_batch = self.criterion(output, y)
        loss = torch.zeros(1, device=loss_batch.device)
        for T, loss_sample in zip(T_ys, loss_batch):
            loss += torch.sum(loss_sample[:, :, :T]) / T

        return loss

    @torch.no_grad()
    def should_stop(self, loss_valid, epoch):
        if epoch == self.max_epochs - 1:
            return True
        self.scheduler.step()
        # early stopping criterion
        # if self.scheduler.t_epoch == 0:  # if it is restarted now
        #     # if self.loss_last_restart < loss_valid:
        #     #     return True
        #     if self.loss_last_restart * hp.threshold_stop < loss_valid:
        #         self.max_epochs = epoch + self.scheduler.restart_period + 1
        #     self.loss_last_restart = loss_valid

    def train(self, loader_train: DataLoader, loader_valid: DataLoader,
              logdir: Path, first_epoch=0):
        # Learning Rate Scheduler
        self.scheduler = CosineLRWithRestarts(self.optimizer,
                                              batch_size=hp.batch_size,
                                              epoch_size=len(loader_train.dataset),
                                              last_epoch=first_epoch - 1,
                                              **hp.scheduler)
        self.scheduler.step()

        self.writer = CustomWriter(str(logdir), group='train', purge_step=first_epoch)

        # write DNN structure to tensorboard. not properly work in PyTorch 1.3
        # self.writer.add_graph(
        #     self.model.module if hasattr(self.model, 'module') else self.model,
        #     torch.zeros(1, hp.UNet['ch_in'], 256, 256, device=self.in_device),
        # )

        # Start Training
        for epoch in range(first_epoch, hp.n_epochs):
            print()
            pbar = tqdm(loader_train, desc=f'epoch {epoch:3d}', postfix='[]', dynamic_ncols=True)
            avg_loss = AverageMeter(float)

            for i_iter, data in enumerate(pbar):
                # get data
                x, y = self.preprocess(data)  # B, C, F, T
                T_ys = data['T_ys']

                # forward
                output = self.model(x)[..., :y.shape[-1]]  # B, C, F, T
                loss = self.calc_loss(output, y, T_ys)

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.batch_step()

                # print
                avg_loss.update(loss.item(), len(T_ys))
                pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

            self.writer.add_scalar('loss/train', avg_loss.get_average(), epoch)

            # Validation
            loss_valid = self.validate(loader_valid, logdir, epoch)

            # save loss & model
            if epoch % hp.period_save_state == hp.period_save_state - 1:
                torch.save((
                    self.model.module.state_dict(),
                    self.optimizer.state_dict(),
                ), logdir / f'{epoch}.pt')

            # Early stopping
            if self.should_stop(loss_valid, epoch):
                break

        self.writer.close()

    @torch.no_grad()
    def validate(self, loader: DataLoader, logdir: Path, epoch: int):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch:
        """
        self.model.eval()

        avg_loss = AverageMeter(float)
        pbar = tqdm(loader, desc='validate ', postfix='[0]', dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data
            x, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output = self.model(x)[..., :y.shape[-1]]

            # loss
            loss = self.calc_loss(output, y, T_ys)
            avg_loss.update(loss.item(), len(T_ys))

            # print
            pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

            # write summary
            if i_iter == 0:
                # F, T, C
                if epoch == 0:
                    one_sample = DirSpecDataset.decollate_padded(data, 0)
                else:
                    one_sample = dict()

                out_one = self.postprocess(output, T_ys, 0, loader.dataset)

                # DirSpecDataset.save_dirspec(
                #     logdir / hp.form_result.format(epoch),
                #     **one_sample, **out_one
                # )

                self.writer.write_one(epoch, **one_sample, **out_one)

        self.writer.add_scalar('loss/valid', avg_loss.get_average(), epoch)

        self.model.train()

        return avg_loss.get_average()

    @torch.no_grad()
    def test(self, loader: DataLoader, logdir: Path):
        def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
            """ save forward propagation data """
            module_name = str(module).split('(')[0]
            dict_to_save = dict()
            # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
            dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

            i_module = module_counts[module_name]
            for i, o in enumerate(dict_to_save['out']):
                save_forward.writer.add_figure(
                    f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                    draw_spectrogram(o, to_db=False),
                    i,
                )
            scio.savemat(
                str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
                dict_to_save,
            )
            module_counts[module_name] += 1

        group = logdir.name.split('_')[0]
        self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()

        # register hook to save output of each block
        module_counts = None
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            save_forward.writer = self.writer
            if isinstance(self.model, nn.DataParallel):
                module = self.model.module
            else:
                module = self.model
            for sub in module.children():
                if isinstance(sub, nn.ModuleList):
                    for m in sub:
                        m.register_forward_hook(save_forward)
                elif isinstance(sub, nn.ModuleDict):
                    for m in sub.values():
                        m.register_forward_hook(save_forward)
                else:
                    sub.register_forward_hook(save_forward)

        pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data
            x, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            if module_counts is not None:
                module_counts = defaultdict(int)

            if 0 < hp.n_save_block_outs == i_iter:
                break
            output = self.model(x)  # [..., :y.shape[-1]]

            # write summary
            one_sample = DirSpecDataset.decollate_padded(data, 0)  # F, T, C
            out_one = self.postprocess(output, T_ys, 0, loader.dataset)

            # DirSpecDataset.save_dirspec(
            #     logdir / hp.form_result.format(i_iter),
            #     **one_sample, **out_one
            # )

            measure = self.writer.write_one(
                i_iter,
                eval_with_y_ph=hp.eval_with_y_ph,
                **out_one, **one_sample,
            )
            if avg_measure is None:
                avg_measure = AverageMeter(init_value=measure, init_count=len(T_ys))
            else:
                avg_measure.update(measure)

        self.model.train()

        avg_measure = avg_measure.get_average()
        self.writer.add_text('Average Measure/Proposed', str(avg_measure[0]))
        self.writer.add_text('Average Measure/Reverberant', str(avg_measure[1]))
        self.writer.close()  # Explicitly close

        print()
        str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
        print(f'Average: {str_avg_measure}')

    @torch.no_grad()
    def save_result(self, loader: DataLoader, logdir: Path):
        """ save results in npz files without evaluation for deep griffin-lim algorithm

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        """
        import numpy as np

        self.model.eval()

        # avg_loss = AverageMeter(float)
        pbar = tqdm(loader, desc='save ', dynamic_ncols=True)
        i_cum = 0
        for i_iter, data in enumerate(pbar):
            # get data
            x, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output = self.model(x)[..., :y.shape[-1]]  # B, C, F, T
            output = output.permute(0, 2, 3, 1)  # B, F, T, C
            out_denorm = loader.dataset.denormalize_(y=output).cpu().numpy()
            np.maximum(out_denorm, 0, out=out_denorm)
            out_denorm = out_denorm.squeeze()  # B, F, T

            # B, F, T
            x_phase = data['x_phase'][..., :y.shape[-1], 0].numpy()
            y_phase = data['y_phase'].numpy().squeeze()

            out_x_ph = out_denorm * np.exp(1j * x_phase)
            out_y_ph = out_denorm * np.exp(1j * y_phase)

            for i_b, T in enumerate(T_ys):
                # F, T
                noisy = np.ascontiguousarray(out_x_ph[i_b, ..., :T])
                clean = np.ascontiguousarray(out_y_ph[i_b, ..., :T])
                mag = np.ascontiguousarray(out_denorm[i_b, ..., :T])
                length = hp.n_fft + hp.l_hop * (T - 1) - hp.n_fft // 2 * 2
                spec_data = dict(spec_noisy=noisy, spec_clean=clean, mag_clean=mag, length=length)
                np.savez(str(logdir / f'{i_cum + i_b}.npz'), **spec_data)
            i_cum += len(T_ys)

        self.model.train()
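# ---------------------------------------------------------------------------
# Reading back the files written by `save_result` (illustrative sketch; the
# consumer script is an assumption, but the keys match the `spec_data` dict
# saved above):
#
#     import numpy as np
#     npz = np.load('0.npz')
#     spec_noisy = npz['spec_noisy']   # complex spectrogram with noisy phase (F, T)
#     spec_clean = npz['spec_clean']   # complex spectrogram with clean phase (F, T)
#     mag_clean = npz['mag_clean']     # magnitude spectrogram (F, T)
#     length = int(npz['length'])      # target waveform length in samples
# ---------------------------------------------------------------------------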
class Trainer:
    def __init__(self, path_state_dict=''):
        self.writer: Optional[CustomWriter] = None

        config = {'vanilla': hp.vanilla_model, "ed": hp.ed_model}[hp.model_type.lower()]
        self.model = DeGLI(self.writer, config, hp.model_type, hp.n_freq, hp.use_fp16, **hp.model)
        count_parameters(self.model)

        self.criterion = nn.L1Loss(reduction='none')

        if hp.optimizer == "adam":
            self.optimizer = Adam(self.model.parameters(), lr=hp.learning_rate,
                                  weight_decay=hp.weight_decay)
        elif hp.optimizer == "sgd":
            self.optimizer = SGD(self.model.parameters(), lr=hp.learning_rate,
                                 weight_decay=hp.weight_decay)
        elif hp.optimizer == "radam":
            self.optimizer = RAdam(self.model.parameters(), lr=hp.learning_rate,
                                   weight_decay=hp.weight_decay)
        elif hp.optimizer == "novograd":
            self.optimizer = NovoGrad(self.model.parameters(), lr=hp.learning_rate,
                                      weight_decay=hp.weight_decay)
        elif hp.optimizer == "sm3":
            raise NameError('sm3 not implemented')
        else:
            raise NameError('optimizer not implemented')

        self.module = self.model
        # self.optimizer = SGD(self.model.parameters(),
        #                      lr=hp.learning_rate,
        #                      weight_decay=hp.weight_decay,
        #                      )

        self.__init_device(hp.device, hp.out_device)

        if hp.use_fp16:
            from apex import amp
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1')

        self.reused_sample = None
        self.result_eval_glim = None

        ##if hp.optimizer == "novograd":
        ##    self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, 744*3, 1e-4)
        ##else:
        self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, **hp.scheduler)
        self.max_epochs = hp.n_epochs

        self.valid_eval_sample: Dict[str, Any] = dict()

        # if hp.model['final_avg']:
        #     len_weight = hp.repeat_train
        # else:
        #     len_weight = hp.model['depth'] * hp.repeat_train
        len_weight = hp.repeat_train
        self.loss_weight = torch.tensor(
            [1. / i for i in range(len_weight, 0, -1)],
            device=self.out_device,
        )
        self.loss_weight /= self.loss_weight.sum()

        # Load State Dict
        if path_state_dict:
            st_model, st_optim, st_sched = torch.load(path_state_dict, map_location=self.in_device)
            try:
                self.module.load_state_dict(st_model)
                self.optimizer.load_state_dict(st_optim)
                self.scheduler.load_state_dict(st_sched)
            except Exception:
                raise Exception('The model is different from the state dict.')

        path_summary = hp.logdir / 'summary.txt'
        if not path_summary.exists():
            # print_to_file(
            #     path_summary,
            #     summary,
            #     (self.model, hp.dummy_input_size),
            #     dict(device=self.str_device[:4])
            # )
            with path_summary.open('w') as f:
                f.write('\n')
            with (hp.logdir / 'hparams.txt').open('w') as f:
                f.write(repr(hp))

    def __init_device(self, device, out_device):
        """
        :type device: Union[int, str, Sequence]
        :type out_device: Union[int, str, Sequence]
        :return:
        """
        if device == 'cpu':
            self.in_device = torch.device('cpu')
            self.out_device = torch.device('cpu')
            self.str_device = 'cpu'
            return

        # normalize `device` to List[int]
        if type(device) == int:
            device = [device]
        elif type(device) == str:
            if device[0] == 'a':
                device = [x for x in range(torch.cuda.device_count())]
            else:
                device = [int(d.replace('cuda:', '')) for d in device.split(",")]
            print("Used devices = %s" % device)
        else:  # sequence of devices
            if type(device[0]) != int:
                device = [int(d.replace('cuda:', '')) for d in device]

        self.in_device = torch.device(f'cuda:{device[0]}')

        if len(device) > 1:
            if type(out_device) == int:
                self.out_device = torch.device(f'cuda:{out_device}')
            else:
                self.out_device = torch.device(out_device)
            self.out_device = 0
            self.str_device = ', '.join([f'cuda:{d}' for d in device])
            self.model = nn.DataParallel(self.model, device_ids=device,
                                         output_device=self.out_device)
        else:
            self.out_device = self.in_device
            self.str_device = str(self.in_device)

        self.model.cuda(self.in_device)
        self.criterion.cuda(self.out_device)
        torch.cuda.set_device(self.in_device)

    def preprocess(self, data: Dict[str, Tensor]) -> Tuple[Tensor, Tensor, int, Tensor]:
        # B, F, T, C
        x = data['x']
        mag = data['y_mag']
        max_length = max(data['length'])
        y = data['y']

        x = x.to(self.in_device, non_blocking=True)
        mag = mag.to(self.in_device, non_blocking=True)
        y = y.to(self.out_device, non_blocking=True)

        return x, mag, max_length, y

    @torch.no_grad()
    def postprocess(self, output: Tensor, residual: Tensor, Ts: ndarray,
                    idx: int, dataset: ComplexSpecDataset) -> Dict[str, ndarray]:
        dict_one = dict(out=output, res=residual)
        for key in dict_one:
            if dict_one[key] is None:
                continue
            one = dict_one[key][idx, :, :, :Ts[idx]]
            one = one.permute(1, 2, 0).contiguous()  # F, T, 2
            one = one.cpu().numpy().view(dtype=np.complex64)  # F, T, 1
            dict_one[key] = one
        return dict_one

    def calc_loss(self, out_blocks: Tensor, y: Tensor, T_ys: Sequence[int]) -> Tensor:
        """
        out_blocks: B, depth, C, F, T
        y: B, C, F, T
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            loss_no_red = self.criterion(out_blocks, y.unsqueeze(1))
        loss_blocks = torch.zeros(out_blocks.shape[1], device=y.device)
        for T, loss_batch in zip(T_ys, loss_no_red):
            loss_blocks += torch.mean(loss_batch[..., :T], dim=(1, 2, 3))

        if len(loss_blocks) == 1:
            loss = loss_blocks.squeeze()
        else:
            loss = loss_blocks @ self.loss_weight
        return loss

    @torch.no_grad()
    def should_stop(self, loss_valid, epoch):
        if epoch == self.max_epochs - 1:
            return True
        self.scheduler.step(loss_valid)
        # if self.scheduler.t_epoch == 0:  # if it is restarted now
        #     # if self.loss_last_restart < loss_valid:
        #     #     return True
        #     if self.loss_last_restart * hp.threshold_stop < loss_valid:
        #         self.max_epochs = epoch + self.scheduler.restart_period + 1
        #     self.loss_last_restart = loss_valid

    def train(self, loader_train: DataLoader, loader_valid: DataLoader,
              logdir: Path, first_epoch=0):
        self.writer = CustomWriter(str(logdir), group='train', purge_step=first_epoch)

        # Start Training
        step = 0
        loss_valid = self.validate(loader_valid, logdir, 0, step, repeat=hp.repeat_train)

        for epoch in range(first_epoch, hp.n_epochs):
            self.writer.add_scalar('loss/lr', self.optimizer.param_groups[0]['lr'], epoch)
            pbar = tqdm(loader_train, desc=f'epoch {epoch:3d}', postfix='[]', dynamic_ncols=True)
            avg_loss = AverageMeter(float)
            avg_grad_norm = AverageMeter(float)

            for i_iter, data in enumerate(pbar):
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                T_ys = data['T_ys']

                # forward
                output_loss, _, _ = self.model(x, mag, max_length, repeat=hp.repeat_train,
                                               train_step=step)  # B, C, F, T
                step = step + 1
                loss = self.calc_loss(output_loss, y, T_ys)

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), hp.thr_clip_grad)
                self.optimizer.step()

                # print
                # if np.any(np.isnan(loss.item())):
                #     raise NameError('Loss is Nan!')
                # for vname, var in self.model.named_parameters():
                #     if np.any(np.isnan(var.detach().cpu().numpy())):
                #         print("nan detected in %s " % vname)
                #         ##import pdb; pdb.set_trace()
                avg_loss.update(loss.item(), len(T_ys))
                pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')
                avg_grad_norm.update(grad_norm)

                if i_iter % 25 == 0:
                    self.writer.add_scalar('loss/train', avg_loss.get_average(),
                                           epoch * len(loader_train) + i_iter)
                    self.writer.add_scalar('loss/grad', avg_grad_norm.get_average(),
                                           epoch * len(loader_train) + i_iter)
                    avg_loss = AverageMeter(float)
                    avg_grad_norm = AverageMeter(float)

            # Validation
            # loss_valid = self.validate(loader_valid, logdir, epoch)
            loss_valid = self.validate(loader_valid, logdir, epoch + 1, step, repeat=hp.repeat_train)

            # save loss & model
            if epoch % hp.period_save_state == hp.period_save_state - 1:
                torch.save((
                    self.module.state_dict(),
                    self.optimizer.state_dict(),
                    self.scheduler.state_dict(),
                ), logdir / f'{epoch+1}.pt')

            # Early stopping
            if self.should_stop(loss_valid, epoch):
                break

        self.writer.close()

    @torch.no_grad()
    def validate(self, loader: DataLoader, logdir: Path, epoch: int, step, repeat=1):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch:
        """
        suffix = f'_{repeat}' if repeat > 1 else ''

        self.model.eval()

        stoi_cnt = 0
        stoi_cntX = 0
        stoi_iters = hp.stoi_iters
        stoi_iters_rate = hp.stoi_iters_rate

        avg_loss = AverageMeter(float)
        avg_measure = AverageMeter(float)
        pesq_avg_measure = AverageMeter(float)
        avg_measureX = AverageMeter(float)
        pesq_avg_measureX = AverageMeter(float)

        pbar = tqdm(loader, desc='validate ', postfix='[0]', dynamic_ncols=True)
        num_iters = len(pbar)
        for i_iter, data in enumerate(pbar):
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output_loss, output, residual = self.model(x, mag, max_length, repeat=repeat,
                                                       train_step=step)

            # loss
            loss = self.calc_loss(output_loss, y, T_ys)
            avg_loss.update(loss.item(), len(T_ys))

            # print
            pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

            # write summary
            # if i_iter == 0:
            #     if self.reused_sample is None:
            #         one_sample = ComplexSpecDataset.decollate_padded(data, i_iter)
            #         self.reused_sample, self.result_eval_glim = self.writer.write_zero(
            #             0, i_iter, **one_sample, suffix="Base stats")
            #     out_one = self.postprocess(output, residual, T_ys, i_iter, loader.dataset)
            #     self.writer.write_one(0, i_iter, self.result_eval_glim, self.reused_sample,
            #                           **out_one, suffix="deGLI")

            if stoi_cnt <= hp.num_stoi:
                ##import pdb; pdb.set_trace()
                for p in range(min(hp.num_stoi // num_iters, len(T_ys))):
                    y_wav = data['wav'][p]
                    out = self.postprocess(output, None, T_ys, p, None)['out']
                    out_wav = reconstruct_wave(out, n_sample=data['length'][p])
                    measure = calc_using_eval_module(y_wav, out_wav)
                    stoi = measure['STOI']
                    pesq_score = measure['PESQ']
                    avg_measure.update(stoi)
                    pesq_avg_measure.update(pesq_score)
                    stoi_cnt = stoi_cnt + 1

            if (stoi_iters > 0) and (epoch % stoi_iters_rate == 0):
                _, output, _ = self.model(x, mag, max_length, repeat=stoi_iters, train_step=step)
                if stoi_cntX <= hp.num_stoi:
                    ##import pdb; pdb.set_trace()
                    for p in range(min(hp.num_stoi // num_iters, len(T_ys))):
                        y_wav = data['wav'][p]
                        out = self.postprocess(output, None, T_ys, p, None)['out']
                        out_wav = reconstruct_wave(out, n_sample=data['length'][p])
                        measure = calc_using_eval_module(y_wav, out_wav)
                        stoi = measure['STOI']
                        pesq_score = measure['PESQ']
                        avg_measureX.update(stoi)
                        pesq_avg_measureX.update(pesq_score)
                        stoi_cntX = stoi_cntX + 1

        self.writer.add_scalar(f'loss/valid', avg_loss.get_average(), epoch)
        self.writer.add_scalar(f'loss/STOI', avg_measure.get_average(), epoch)
        self.writer.add_scalar(f'loss/PESQ', pesq_avg_measure.get_average(), epoch)
        if (stoi_iters > 0) and (epoch % stoi_iters_rate == 0):
            self.writer.add_scalar(f'loss/PESQ_X{stoi_iters}', pesq_avg_measureX.get_average(), epoch)
            self.writer.add_scalar(f'loss/STOI_X{stoi_iters}', avg_measureX.get_average(), epoch)

        self.model.train()

        return avg_loss.get_average()

    @torch.no_grad()
    def test(self, loader: DataLoader, logdir: Path):
        def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
            module_name = str(module).split('(')[0]
            dict_to_save = dict()
            # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
            dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

            i_module = module_counts[module_name]
            for i, o in enumerate(dict_to_save['out']):
                save_forward.writer.add_figure(
                    f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                    draw_spectrogram(o, to_db=False),
                    i,
                )
            scio.savemat(
                str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
                dict_to_save,
            )
            module_counts[module_name] += 1

        group = logdir.name.split('_')[0]
        if self.writer is None:
            self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()
        depth = hp.model['depth']

        module_counts = None
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            save_forward.writer = self.writer
            for sub in self.module.children():
                if isinstance(sub, nn.ModuleList):
                    for m in sub:
                        m.register_forward_hook(save_forward)
                elif isinstance(sub, nn.ModuleDict):
                    for m in sub.values():
                        m.register_forward_hook(save_forward)
                else:
                    sub.register_forward_hook(save_forward)

        ##pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        cnt_sample = 0
        for i_iter, data in enumerate(loader):
            sampleDict = {}

            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
            if hp.noisy_init:
                x = torch.normal(0, 1, x.shape).cuda(self.in_device)
            T_ys = data['T_ys']

            # forward
            if module_counts is not None:
                module_counts = defaultdict(int)
            # if 0 < hp.n_save_block_outs == i_iter:
            #     break

            repeats = 1
            for _ in range(3):  ## warm up!
                _, output, residual = self.model(x, mag, max_length, repeat=1, train_step=1)
                _, output = self.model.plain_gla(x, mag, max_length, repeat=repeats)

            while repeats <= hp.repeat_test:
                stime = ms()
                _, output, residual = self.model(x, mag, max_length, repeat=repeats, train_step=1)
                avg_measure = AverageMeter()
                avg_measure2 = AverageMeter()
                etime = ms(stime)
                speed = (max_length / hp.fs) * len(T_ys) / (etime / 1000)
                ##print("degli: %d repeats, length: %d, time: %d miliseconds, ratio = %.02f"
                ##      % (repeats, max_length, etime, speed))
                ##self.writer.add_scalar("Test Performance/degli", speed, repeats)

                # write summary
                for i_b in tqdm(range(len(T_ys)), desc="degli, %d repeats" % repeats,
                                dynamic_ncols=True):
                    i_sample = cnt_sample + i_b
                    if i_b not in sampleDict:
                        one_sample = ComplexSpecDataset.decollate_padded(data, i_b)
                        reused_sample, result_eval_glim = self.writer.write_zero(
                            0, i_b, **one_sample, suffix="Base stats")
                        sampleDict[i_b] = (reused_sample, result_eval_glim)

                    sampleItem = sampleDict[i_b]
                    reused_sample = sampleItem[0]
                    result_eval_glim = sampleItem[1]

                    out_one = self.postprocess(output, residual, T_ys, i_b, loader.dataset)

                    # ComplexSpecDataset.save_dirspec(
                    #     logdir / hp.form_result.format(i_sample),
                    #     **one_sample, **out_one
                    # )

                    measure = self.writer.write_one(repeats, i_b, result_eval_glim, reused_sample,
                                                    **out_one, suffix="3_deGLI")
                    avg_measure.update(measure)

                stime = ms()
                _, output = self.model.plain_gla(x, mag, max_length, repeat=repeats)
                etime = ms(stime)
                speed = (1000 * max_length / hp.fs) * len(T_ys) / (etime)
                ##print("pure gla: %d repeats, length: %d, time: %d miliseconds, ratio = %.02f"
                ##      % (repeats, max_length, etime, speed))
                ##self.writer.add_scalar("Test Performance/gla", speed, repeats)

                # write summary
                for i_b in tqdm(range(len(T_ys)), desc="GLA, %d repeats" % repeats,
                                dynamic_ncols=True):
                    i_sample = cnt_sample + i_b
                    sampleItem = sampleDict[i_b]
                    reused_sample = sampleItem[0]
                    result_eval_glim = sampleItem[1]
                    out_one = self.postprocess(output, None, T_ys, i_b, loader.dataset)
                    measure = self.writer.write_one(repeats, i_b, result_eval_glim, reused_sample,
                                                    **out_one, suffix="4_GLA")
                    avg_measure2.update(measure)

                cnt_sample += len(T_ys)

                self.writer.add_scalar(f'STOI/Average Measure/deGLI',
                                       avg_measure.get_average()[0, 0], repeats * depth)
                self.writer.add_scalar(f'STOI/Average Measure/GLA',
                                       avg_measure2.get_average()[0, 0], repeats * depth)
                self.writer.add_scalar(f'STOI/Average Measure/deGLI_semilogx',
                                       avg_measure.get_average()[0, 0],
                                       int(repeats * depth).bit_length())
                self.writer.add_scalar(f'STOI/Average Measure/GLA_semilogx',
                                       avg_measure2.get_average()[0, 0],
                                       int(repeats * depth).bit_length())
                self.writer.add_scalar(f'PESQ/Average Measure/deGLI',
                                       avg_measure.get_average()[0, 1], repeats * depth)
                self.writer.add_scalar(f'PESQ/Average Measure/GLA',
                                       avg_measure2.get_average()[0, 1], repeats * depth)
                self.writer.add_scalar(f'PESQ/Average Measure/deGLI_semilogx',
                                       avg_measure.get_average()[0, 1],
                                       int(repeats * depth).bit_length())
                self.writer.add_scalar(f'PESQ/Average Measure/GLA_semilogx',
                                       avg_measure2.get_average()[0, 1],
                                       int(repeats * depth).bit_length())

                repeats = repeats * 2

            break  # only the first batch is evaluated

        self.model.train()
        self.writer.close()  # Explicitly close

        ##print()
        ##str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
        ##print(f'Average: {str_avg_measure}')

    @torch.no_grad()
    def speedtest(self, loader: DataLoader, logdir: Path):
        group = logdir.name.split('_')[0]
        if self.writer is None:
            self.writer = CustomWriter(str(logdir), group=group)

        depth = hp.model['depth']
        ##pbar = tqdm(loader, desc=group, dynamic_ncols=True)

        repeats = 1
        while repeats * depth <= hp.repeat_test:
            pbar = tqdm(loader, desc="degli performance, %d repeats" % repeats, dynamic_ncols=True)
            stime = time()
            tot_len = 0
            for i_iter, data in enumerate(pbar):
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                _, output, residual = self.model(x, mag, max_length, repeat=repeats, train_step=1)
                tot_len = tot_len + max_length * x.size(0)
            etime = int(time() - stime)
            speed = (tot_len / hp.sampling_rate) / (etime)
            self.writer.add_scalar("Test Performance/degli", speed, repeats * depth)
            self.writer.add_scalar("Test Performance/degli_semilogx", speed,
                                   int(repeats * depth).bit_length())
            repeats = repeats * 2

        repeats = 1
        while repeats * depth <= hp.repeat_test:
            stime = time()
            pbar = tqdm(loader, desc="GLA performance, %d repeats" % repeats, dynamic_ncols=True)
            tot_len = 0
            for i_iter, data in enumerate(pbar):
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                _, output = self.model.plain_gla(x, mag, max_length, repeat=repeats)
                tot_len = tot_len + max_length * x.size(0)
            etime = int(time() - stime)
            speed = (tot_len / hp.sampling_rate) / (etime)
            self.writer.add_scalar("Test Performance/gla", speed, repeats * depth)
            self.writer.add_scalar("Test Performance/gla_semilogx", speed,
                                   int(repeats * depth).bit_length())
            repeats = repeats * 2

        self.model.train()
        self.writer.close()  # Explicitly close
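# ---------------------------------------------------------------------------
# Note on the evaluation schedule used by `test` and `speedtest` above: the
# number of iterations is doubled each pass (1, 2, 4, ...) until the budget in
# `hp.repeat_test` is exhausted, and each scalar is logged twice, once at the
# linear step `repeats * depth` and once at `int(repeats * depth).bit_length()`
# so TensorBoard shows an approximate semilog-x curve.  A minimal sketch of the
# same schedule (illustrative only; `evaluate` is a hypothetical callback):
#
#     repeats = 1
#     while repeats * depth <= hp.repeat_test:
#         evaluate(repeats)
#         repeats *= 2
# ---------------------------------------------------------------------------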
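# ---------------------------------------------------------------------------
# The Trainer class that follows (see calc_loss and calc_loss_smooth2 below)
# averages its losses only over the valid frames of each utterance: batches
# are zero-padded along time and T_ys holds the true length of every item.
# The helper below is a minimal sketch of that masking, assuming an
# un-reduced criterion (reduction='none') and B, C, F, T tensors; it is
# illustrative, not the class's exact code path.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


def masked_mean_loss(x: torch.Tensor, y: torch.Tensor, T_ys,
                     crit=nn.L1Loss(reduction='none')) -> torch.Tensor:
    """Average an element-wise loss over only the first T_ys[b] frames of each item."""
    loss_no_red = crit(x, y)                   # B, C, F, T, element-wise
    total, n_frames = x.new_zeros(()), 0
    for T, loss_b in zip(T_ys, loss_no_red):   # iterate over the batch
        total = total + loss_b[..., :T].sum()  # drop the padded frames
        n_frames += int(T)
    return total / n_frames


# example: batch of 2 utterances padded to 100 frames, true lengths 80 and 100
#   x, y = torch.rand(2, 1, 257, 100), torch.rand(2, 1, 257, 100)
#   masked_mean_loss(x, y, T_ys=[80, 100])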
class Trainer:
    def __init__(self, path_state_dict=''):
        self.writer: Optional[CustomWriter] = None

        meltrans = create_mel_filterbank(hp.sampling_rate, hp.n_fft, fmin=hp.mel_fmin,
                                         fmax=hp.mel_fmax, n_mels=hp.mel_freq)

        self.model = melGen(self.writer, hp.n_freq, meltrans, hp.mel_generator)
        count_parameters(self.model)
        self.module = self.model

        self.lws_processor = lws.lws(hp.n_fft, hp.l_hop, mode='speech', perfectrec=False)
        self.prev_stoi_scores = {}
        self.base_stoi_scores = {}

        if hp.crit == "l1":
            self.criterion = nn.L1Loss(reduction='none')
        elif hp.crit == "l2":
            # nn.L2Loss does not exist in PyTorch; MSELoss is the L2 criterion.
            self.criterion = nn.MSELoss(reduction='none')
        else:
            raise NameError(f'criterion "{hp.crit}" not implemented')
        self.criterion2 = nn.L1Loss(reduction='none')

        # (kernel, stride) pairs for the envelope losses, selected by hp.loss_mode
        self.f_specs = {
            0: [(5, 2), (15, 5)],
            1: [(5, 2)],
            2: [(3, 1)],
            3: [(3, 1), (5, 2)],
            4: [(3, 1), (5, 2), (7, 3)],
            5: [(15, 5)],
            6: [(3, 1), (5, 2), (7, 3), (15, 5), (25, 10)],
            7: [(1, 1)],
            8: [(1, 1), (3, 1), (5, 2), (15, 5), (7, 3), (25, 10), (9, 4), (20, 5), (5, 3)],
        }[hp.loss_mode]
        self.filters = [gen_filter(k) for k, s in self.f_specs]

        if hp.optimizer == "adam":
            self.optimizer = Adam(self.model.parameters(),
                                  lr=hp.learning_rate,
                                  weight_decay=hp.weight_decay)
        elif hp.optimizer == "sgd":
            self.optimizer = SGD(self.model.parameters(),
                                 lr=hp.learning_rate,
                                 weight_decay=hp.weight_decay)
        elif hp.optimizer == "radam":
            self.optimizer = RAdam(self.model.parameters(),
                                   lr=hp.learning_rate,
                                   weight_decay=hp.weight_decay)
        elif hp.optimizer == "novograd":
            self.optimizer = NovoGrad(self.model.parameters(),
                                      lr=hp.learning_rate,
                                      weight_decay=hp.weight_decay)
        elif hp.optimizer == "sm3":
            raise NameError('sm3 not implemented')
        else:
            raise NameError('optimizer not implemented')

        self.__init_device(hp.device)

        # if hp.optimizer == "novograd":
        #     self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, 200, 1e-5)
        # else:
        self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, **hp.scheduler)
        self.max_epochs = hp.n_epochs
        self.valid_eval_sample: Dict[str, Any] = dict()

        # len_weight = hp.repeat_train
        # self.loss_weight = torch.tensor(
        #     [1. / i for i in range(len_weight, 0, -1)],
        # )
        # self.loss_weight /= self.loss_weight.sum()

        # Load State Dict
        if path_state_dict:
            st_model, st_optim, st_sched = torch.load(path_state_dict,
                                                      map_location=self.in_device)
            try:
                self.module.load_state_dict(st_model)
                self.optimizer.load_state_dict(st_optim)
                self.scheduler.load_state_dict(st_sched)
            except Exception as e:
                raise Exception('The model is different from the state dict.') from e

        path_summary = hp.logdir / 'summary.txt'
        if not path_summary.exists():
            with path_summary.open('w') as f:
                f.write('\n')
            with (hp.logdir / 'hparams.txt').open('w') as f:
                f.write(repr(hp))

    def __init_device(self, device):
        """
        :type device: Union[int, str, Sequence]
        """
        # normalize `device` to a list of CUDA device indices
        if type(device) == int:
            device = [device]
        elif type(device) == str:
            if device[0] == 'a':  # 'all': use every visible GPU
                device = [x for x in range(torch.cuda.device_count())]
            else:
                device = [int(d.replace('cuda:', '')) for d in device.split(",")]
            print("Used devices = %s" % device)
        else:  # sequence of devices
            if type(device[0]) != int:
                device = [int(d.replace('cuda:', '')) for d in device]

        self.num_workers = len(device)
        if len(device) > 1:
            self.model = nn.DataParallel(self.model, device_ids=device)
        self.in_device = torch.device(f'cuda:{device[0]}')
        torch.cuda.set_device(self.in_device)
        self.model.cuda()
self.criterion.cuda() self.criterion2.cuda() self.filters = [f.cuda() for f in self.filters] def preprocess(self, data: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: # B, F, T, C y = data['y'] y = y.cuda() return y @torch.no_grad() def postprocess(self, output: Tensor, residual: Tensor, Ts: ndarray, idx: int, dataset: ComplexSpecDataset) -> Dict[str, ndarray]: dict_one = dict(out=output, res=residual) for key in dict_one: if dict_one[key] is None: continue one = dict_one[key][idx, :, :, :Ts[idx]] one = one.permute(1, 2, 0).contiguous() # F, T, 2 one = one.cpu().numpy().view(dtype=np.complex64) # F, T, 1 dict_one[key] = one return dict_one def calc_loss(self, x: Tensor, y: Tensor, T_ys: Sequence[int], crit) -> Tensor: """ out_blocks: B, depth, C, F, T y: B, C, F, T """ with warnings.catch_warnings(): warnings.simplefilter('ignore') loss_no_red = crit(x, y) loss_blocks = torch.zeros(x.shape[1], device=y.device) tot =0 for T, loss_batch in zip(T_ys, loss_no_red): tot += T loss_blocks += torch.sum(loss_batch[..., :T]) loss_blocks = loss_blocks / tot if len(loss_blocks) == 1: loss = loss_blocks.squeeze() else: loss = loss_blocks @ self.loss_weight return loss def calc_loss_smooth(self, _x: Tensor, _y: Tensor, T_ys: Sequence[int], filter, stride: int ,pad: int = 0) -> Tensor: """ out_blocks: B, depth, C, F, T y: B, C, F, T """ crit = self.criterion x = F.conv2d(_x, filter, stride = stride) y = F.conv2d(_y, filter, stride = stride) with warnings.catch_warnings(): warnings.simplefilter('ignore') loss_no_red = crit(x, y) loss_blocks = torch.zeros(x.shape[1], device=y.device) tot =0 for T, loss_batch in zip(T_ys, loss_no_red): tot += T loss_blocks += torch.sum(loss_batch[..., :T]) loss_blocks = loss_blocks / tot if len(loss_blocks) == 1: loss = loss_blocks.squeeze() else: loss = loss_blocks @ self.loss_weight return loss def calc_loss_smooth2(self, _x: Tensor, _y: Tensor, T_ys: Sequence[int], kern: int , stride: int ,pad: int = 0) -> Tensor: """ out_blocks: B, depth, C, F, T y: B, C, F, T """ crit = self.criterion x = F.max_pool2d(_x, (kern, 1), stride = stride ) y = F.max_pool2d(_y, (kern, 1), stride = stride ) with warnings.catch_warnings(): warnings.simplefilter('ignore') loss_no_red = crit(x, y) loss_blocks = torch.zeros(x.shape[1], device=y.device) tot =0 for T, loss_batch in zip(T_ys, loss_no_red): tot += T loss_blocks += torch.sum(loss_batch[..., :T]) loss_blocks = loss_blocks / tot if len(loss_blocks) == 1: loss1 = loss_blocks.squeeze() else: loss1 = loss_blocks @ self.loss_weight x = F.max_pool2d(-1*_x, (kern, 1), stride = stride ) y = F.max_pool2d(-1*_y, (kern, 1), stride = stride ) with warnings.catch_warnings(): warnings.simplefilter('ignore') loss_no_red = crit(x, y) loss_blocks = torch.zeros(x.shape[1], device=y.device) tot =0 for T, loss_batch in zip(T_ys, loss_no_red): tot += T loss_blocks += torch.sum(loss_batch[..., :T]) loss_blocks = loss_blocks / tot if len(loss_blocks) == 1: loss2 = loss_blocks.squeeze() else: loss2 = loss_blocks @ self.loss_weight loss = loss1 + loss2 return loss @torch.no_grad() def should_stop(self, loss_valid, epoch): if epoch == self.max_epochs - 1: return True self.scheduler.step(loss_valid) # if self.scheduler.t_epoch == 0: # if it is restarted now # # if self.loss_last_restart < loss_valid: # # return True # if self.loss_last_restart * hp.threshold_stop < loss_valid: # self.max_epochs = epoch + self.scheduler.restart_period + 1 # self.loss_last_restart = loss_valid def train(self, loader_train: DataLoader, loader_valid: DataLoader, logdir: Path, 
first_epoch=0): os.makedirs(Path(logdir), exist_ok=True) self.writer = CustomWriter(str(logdir), group='train', purge_step=first_epoch) # Start Training step = 0 loss_valid = self.validate(loader_valid, logdir, 0) l2_factor = hp.l2_factor num_filters = len(self.filters) for epoch in range(first_epoch, hp.n_epochs): self.writer.add_scalar('meta/lr', self.optimizer.param_groups[0]['lr'], epoch) pbar = tqdm(loader_train, desc=f'epoch {epoch:3d}', postfix='[]', dynamic_ncols=True) avg_loss1 = AverageMeter(float) avg_loss2 = AverageMeter(float) avg_loss_tot = AverageMeter(float) avg_losses = [AverageMeter(float) for _ in range(num_filters) ] losses = [None] * num_filters avg_grad_norm = AverageMeter(float) for i_iter, data in enumerate(pbar): # get data ##import pdb; pdb.set_trace() y = self.preprocess(data) x_mel = self.model.spec_to_mel(y) T_ys = data['T_ys'] # forward x = self.model(x_mel) y_mel = self.model.spec_to_mel(x) step = step + 1 loss1 = self.calc_loss(x , y , T_ys, self.criterion) loss2 = self.calc_loss(x_mel, y_mel, T_ys, self.criterion2) loss = loss1+ l2_factor*loss2 # for i,f in enumerate(self.filters): # s = self.f_specs[i][1] # losses[i] = self.calc_loss_smooth(x,y,T_ys,f, s ) # loss = loss + losses[i] for i,(k,s) in enumerate(self.f_specs): losses[i] = self.calc_loss_smooth2(x,y,T_ys,k, s ) loss = loss + losses[i] # backward self.optimizer.zero_grad() loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), hp.thr_clip_grad) self.optimizer.step() # print avg_loss1.update(loss1.item(), len(T_ys)) avg_loss2.update(loss2.item(), len(T_ys)) avg_loss_tot.update(loss.item(), len(T_ys)) for j,l in enumerate(losses): avg_losses[j].update(l.item(), len(T_ys)) pbar.set_postfix_str(f'{avg_loss1.get_average():.1e}') avg_grad_norm.update(grad_norm) if i_iter % 25 == 0: self.writer.add_scalar('loss/loss1_train', avg_loss1.get_average(), epoch*len(loader_train)+ i_iter) self.writer.add_scalar('loss/loss2_train', avg_loss2.get_average(), epoch*len(loader_train)+ i_iter) for j, avg_loss in enumerate(avg_losses): k = self.f_specs[j][0] s = self.f_specs[j][1] self.writer.add_scalar(f'loss/losses_{k}_{s}_train', avg_loss.get_average(), epoch*len(loader_train)+ i_iter) self.writer.add_scalar('loss/loss_total_train', avg_loss_tot.get_average(), epoch*len(loader_train)+ i_iter) self.writer.add_scalar('loss/grad', avg_grad_norm.get_average(), epoch*len(loader_train) + i_iter) avg_loss1 = AverageMeter(float) avg_loss2 = AverageMeter(float) avg_loss_tot = AverageMeter(float) avg_losses = [AverageMeter(float) for _ in range(num_filters) ] avg_grad_norm = AverageMeter(float) # Validation # loss_valid = self.validate(loader_valid, logdir, epoch) loss_valid = self.validate(loader_valid, logdir, epoch+1) # save loss & model if epoch % hp.period_save_state == hp.period_save_state - 1: torch.save( (self.module.state_dict(), self.optimizer.state_dict(), self.scheduler.state_dict(), ), logdir / f'{epoch+1}.pt' ) # Early stopping if self.should_stop(loss_valid, epoch): break self.writer.close() def audio_from_mag_spec(self, mag_spec): mag_spec = mag_spec.astype(np.float64) spec_lws = self.lws_processor.run_lws(np.transpose(mag_spec)) magspec_inv = self.lws_processor.istft(spec_lws)[:, np.newaxis, np.newaxis] magspec_inv = magspec_inv.astype('float32') return magspec_inv @torch.no_grad() def validate(self, loader: DataLoader, logdir: Path, epoch: int): """ Evaluate the performance of the model. :param loader: DataLoader to use. :param logdir: path of the result files. 
:param epoch: """ self.model.eval() num_filters = len(self.filters) avg_stoi = AverageMeter(float) avg_stoi_norm = AverageMeter(float) avg_stoi_base = AverageMeter(float) avg_loss1 = AverageMeter(float) avg_lozz1 = AverageMeter(float) avg_loss2 = AverageMeter(float) avg_lozz2 = AverageMeter(float) avg_loss_tot = AverageMeter(float) avg_losses = [AverageMeter(float) for _ in range(num_filters) ] losses = [None] * num_filters pbar = tqdm(loader, desc='validate ', postfix='[0]', dynamic_ncols=True) num_iters = len(pbar) for i_iter, data in enumerate(pbar): ##import pdb; pdb.set_trace() y = self.preprocess(data) # B, C, F, T x_mel = self.model.spec_to_mel(y) z = self.model.mel_pseudo_inverse(x_mel) T_ys = data['T_ys'] x = self.model(x_mel) # B, C, F, T y_mel = self.model.spec_to_mel(x) z_mel = self.model.spec_to_mel(y) loss1 = self.calc_loss(x, y, T_ys, self.criterion) lozz1 = self.calc_loss(z, y, T_ys, self.criterion) loss2 = self.calc_loss(x_mel, y_mel, T_ys, self.criterion2) lozz2 = self.calc_loss(z_mel, x_mel, T_ys, self.criterion2) loss = loss1 + loss2*hp.l2_factor # for i,f in enumerate(self.filters): # s = self.f_specs[i][1] # losses[i] = self.calc_loss_smooth(x,y,T_ys,f, s ) # loss = loss + losses[i] for i,(k,s) in enumerate(self.f_specs): losses[i] = self.calc_loss_smooth2(x,y,T_ys,k, s ) loss = loss + losses[i] avg_loss1.update(loss1.item(), len(T_ys)) avg_lozz1.update(lozz1.item(), len(T_ys)) avg_loss2.update(loss2.item(), len(T_ys)) avg_lozz2.update(lozz2.item(), len(T_ys)) avg_loss_tot.update(loss.item(), len(T_ys)) for j,l in enumerate(losses): avg_losses[j].update(l.item(), len(T_ys)) # print pbar.set_postfix_str(f'{avg_loss1.get_average():.1e}') ## STOI evaluation with LWS for p in range(min(hp.num_stoi// num_iters,len(T_ys))): _x = x[p,0,:,:T_ys[p]].cpu() _y = y[p,0,:,:T_ys[p]].cpu() _z = z[p,0,:,:T_ys[p]].cpu() audio_x = self.audio_from_mag_spec(np.abs(_x.numpy())) y_wav = data['wav'][p] stoi_score= self.calc_stoi(y_wav, audio_x) avg_stoi.update(stoi_score) if not i_iter in self.prev_stoi_scores: audio_y = self.audio_from_mag_spec(_y.numpy()) audio_z = self.audio_from_mag_spec(_z.numpy()) self.prev_stoi_scores[i_iter] = self.calc_stoi(y_wav, audio_y) self.base_stoi_scores[i_iter] = self.calc_stoi(y_wav, audio_z) avg_stoi_norm.update( stoi_score / self.prev_stoi_scores[i_iter]) avg_stoi_base.update( stoi_score / self.base_stoi_scores[i_iter]) # write summary ## if i_iter < 4: if False: ## stoi is good enough until tests x = x[0,0,:,:T_ys[0]].cpu() y = y[0,0,:,:T_ys[0]].cpu() z = z[0,0,:,:T_ys[0]].cpu() ##import pdb; pdb.set_trace() if i_iter == 3 and hp.request_drawings: ymin = y[y > 0].min() vmin, vmax = librosa.amplitude_to_db(np.array((ymin, y.max()))) kwargs_fig = dict(vmin=vmin, vmax=vmax) fig_x = draw_spectrogram(x, **kwargs_fig) ##self.add_figure(f'{self.group}Audio{idx}/0_Noisy_Spectrum', fig_x, step) self.writer.add_figure(f'Audio{i_iter}/1_DNN_Output', fig_x, epoch) if epoch ==0: fig_y = draw_spectrogram(y, **kwargs_fig) fig_z = draw_spectrogram(z, **kwargs_fig) self.writer.add_figure(f'Audio{i_iter}/0_Pseudo_Inverse', fig_z, epoch) self.writer.add_figure(f'Audio{i_iter}/2_Real_Spectrogram', fig_y, epoch) else: audio_x = self.audio_from_mag_spec(np.abs(x.numpy())) x_scale = np.abs(audio_x).max() / 0.5 self.writer.add_audio(f'Audio{i_iter}/1_DNN_Output', torch.from_numpy(audio_x / x_scale), epoch, sample_rate=hp.sampling_rate) if epoch ==0: audio_y = self.audio_from_mag_spec(y.numpy()) audio_z = self.audio_from_mag_spec(z.numpy()) z_scale = np.abs(audio_z).max() / 0.5 
y_scale = np.abs(audio_y).max() / 0.5 self.writer.add_audio(f'Audio{i_iter}/0_Pseudo_Inverse', torch.from_numpy(audio_z / z_scale), epoch, sample_rate=hp.sampling_rate) self.writer.add_audio(f'Audio{i_iter}/2_Real_Spectrogram', torch.from_numpy(audio_y / y_scale), epoch, sample_rate=hp.sampling_rate) self.writer.add_scalar(f'valid/loss', avg_loss1.get_average(), epoch) self.writer.add_scalar(f'valid/baseline', avg_lozz1.get_average(), epoch) self.writer.add_scalar(f'valid/melinv_loss', avg_loss2.get_average(), epoch) self.writer.add_scalar(f'valid/melinv_baseline', avg_lozz2.get_average(), epoch) self.writer.add_scalar(f'valid/STOI', avg_stoi.get_average(), epoch ) self.writer.add_scalar(f'valid/STOI_normalized', avg_stoi_norm.get_average(), epoch ) self.writer.add_scalar(f'valid/STOI_improvement', avg_stoi_base.get_average(), epoch ) for j, avg_loss in enumerate(avg_losses): k = self.f_specs[j][0] s = self.f_specs[j][1] self.writer.add_scalar(f'valid/losses_{k}_{s}', avg_loss.get_average(), epoch) self.writer.add_scalar('valid/loss_total', avg_loss_tot.get_average(), epoch) self.model.train() return avg_loss1.get_average() def calc_stoi(self, y_wav, audio): audio_len = min(y_wav.shape[0], audio.shape[0] ) measure = calc_using_eval_module(y_wav[:audio_len], audio[:audio_len,0,0]) return measure['STOI'] @torch.no_grad() def test(self, loader: DataLoader, logdir: Path): """ Evaluate the performance of the model. :param loader: DataLoader to use. :param logdir: path of the result files. :param epoch: """ self.model.eval() os.makedirs(Path(logdir), exist_ok=True) self.writer = CustomWriter(str(logdir), group='test') ##import pdb; pdb.set_trace() num_filters = len(self.filters) avg_loss1 = AverageMeter(float) avg_lozz1 = AverageMeter(float) avg_loss2 = AverageMeter(float) avg_lozz2 = AverageMeter(float) avg_loss_tot = AverageMeter(float) avg_losses = [AverageMeter(float) for _ in range(num_filters) ] losses = [None] * num_filters cnt = 0 for i_iter, data in enumerate(loader): ##import pdb; pdb.set_trace() y = self.preprocess(data) # B, C, F, T x_mel = self.model.spec_to_mel(y) z = self.model.mel_pseudo_inverse(x_mel) T_ys = data['T_ys'] x = self.model(x_mel) # B, C, F, T y_mel = self.model.spec_to_mel(x) z_mel = self.model.spec_to_mel(y) loss1 = self.calc_loss(x, y, T_ys, self.criterion) lozz1 = self.calc_loss(z, y, T_ys, self.criterion) loss2 = self.calc_loss(x_mel, y_mel, T_ys, self.criterion2) lozz2 = self.calc_loss(z_mel, x_mel, T_ys, self.criterion2) loss = loss1 + loss2*hp.l2_factor # for i,f in enumerate(self.filters): # s = self.f_specs[i][1] # losses[i] = self.calc_loss_smooth(x,y,T_ys,f, s ) # loss = loss + losses[i] for i,(k,s) in enumerate(self.f_specs): losses[i] = self.calc_loss_smooth2(x,y,T_ys,k, s ) loss = loss + losses[i] avg_loss1.update(loss1.item(), len(T_ys)) avg_lozz1.update(lozz1.item(), len(T_ys)) avg_loss2.update(loss2.item(), len(T_ys)) avg_lozz2.update(lozz2.item(), len(T_ys)) avg_loss_tot.update(loss.item(), len(T_ys)) for j,l in enumerate(losses): avg_losses[j].update(l.item(), len(T_ys)) # print ##pbar.set_postfix_str(f'{avg_loss1.get_average():.1e}') # write summary pbar = tqdm(range(len(T_ys)), desc='validate_bath', postfix='[0]', dynamic_ncols=True) for p in pbar: _x = x[p,0,:,:T_ys[p]].cpu() _y = y[p,0,:,:T_ys[p]].cpu() _z = z[p,0,:,:T_ys[p]].cpu() y_wav = data['wav'][p] ymin = _y[_y > 0].min() vmin, vmax = librosa.amplitude_to_db(np.array((ymin, _y.max()))) kwargs_fig = dict(vmin=vmin, vmax=vmax) if hp.request_drawings: fig_x = draw_spectrogram(_x, 
**kwargs_fig) self.writer.add_figure(f'Audio/1_DNN_Output', fig_x, cnt) fig_y = draw_spectrogram(_y, **kwargs_fig) fig_z = draw_spectrogram(_z, **kwargs_fig) self.writer.add_figure(f'Audio/0_Pseudo_Inverse', fig_z, cnt) self.writer.add_figure(f'Audio/2_Real_Spectrogram', fig_y, cnt) audio_x = self.audio_from_mag_spec(np.abs(_x.numpy())) x_scale = np.abs(audio_x).max() / 0.5 self.writer.add_audio(f'LWS/1_DNN_Output', torch.from_numpy(audio_x / x_scale), cnt, sample_rate=hp.sampling_rate) audio_y = self.audio_from_mag_spec(_y.numpy()) audio_z = self.audio_from_mag_spec(_z.numpy()) z_scale = np.abs(audio_z).max() / 0.5 y_scale = np.abs(audio_y).max() / 0.5 self.writer.add_audio(f'LWS/0_Pseudo_Inverse', torch.from_numpy(audio_z / z_scale), cnt, sample_rate=hp.sampling_rate) self.writer.add_audio(f'LWS/2_Real_Spectrogram', torch.from_numpy(audio_y / y_scale), cnt, sample_rate=hp.sampling_rate) ##import pdb; pdb.set_trace() stoi_scores = {'0_Pseudo_Inverse' : self.calc_stoi(y_wav, audio_z), '1_DNN_Output' : self.calc_stoi(y_wav, audio_x), '2_Real_Spectrogram' : self.calc_stoi(y_wav, audio_y)} self.writer.add_scalars(f'LWS/STOI', stoi_scores, cnt ) # self.writer.add_scalar(f'STOI/0_Pseudo_Inverse_LWS', self.calc_stoi(y_wav, audio_z) , cnt) # self.writer.add_scalar(f'STOI/1_DNN_Output_LWS', self.calc_stoi(y_wav, audio_x) , cnt) # self.writer.add_scalar(f'STOI/2_Real_Spectrogram_LWS', self.calc_stoi(y_wav, audio_y) , cnt) cnt = cnt + 1 # self.writer.add_scalar(f'valid/loss', avg_loss1.get_average(), epoch) # self.writer.add_scalar(f'valid/baseline', avg_lozz1.get_average(), epoch) # self.writer.add_scalar(f'valid/melinv_loss', avg_loss2.get_average(), epoch) # self.writer.add_scalar(f'valid/melinv_baseline', avg_lozz2.get_average(), epoch) # for j, avg_loss in enumerate(avg_losses): # k = self.f_specs[j][0] # s = self.f_specs[j][1] # self.writer.add_scalar(f'valid/losses_{k}_{s}', avg_loss.get_average(), epoch) # self.writer.add_scalar('valid/loss_total', avg_loss_tot.get_average(), epoch) self.model.train() return @torch.no_grad() def inspect(self, loader: DataLoader, logdir: Path): """ Evaluate the performance of the model. :param loader: DataLoader to use. :param logdir: path of the result files. 
:param epoch: """ self.model.eval() os.makedirs(Path(logdir), exist_ok=True) self.writer = CustomWriter(str(logdir), group='test') ##import pdb; pdb.set_trace() num_filters = len(self.filters) avg_loss1 = AverageMeter(float) avg_lozz1 = AverageMeter(float) avg_loss2 = AverageMeter(float) avg_lozz2 = AverageMeter(float) avg_loss_tot = AverageMeter(float) avg_losses = [AverageMeter(float) for _ in range(num_filters) ] avg_losses_base = [AverageMeter(float) for _ in range(num_filters) ] losses = [None] * num_filters losses_base = [None] * num_filters cnt = 0 pbar = tqdm(enumerate(loader), desc='loss inspection', dynamic_ncols=True) for i_iter, data in pbar: ##import pdb; pdb.set_trace() y = self.preprocess(data) # B, C, F, T x_mel = self.model.spec_to_mel(y) z = self.model.mel_pseudo_inverse(x_mel) T_ys = data['T_ys'] x = self.model(x_mel) # B, C, F, T y_mel = self.model.spec_to_mel(x) z_mel = self.model.spec_to_mel(y) loss1 = self.calc_loss(x, y, T_ys, self.criterion) lozz1 = self.calc_loss(z, y, T_ys, self.criterion) loss2 = self.calc_loss(x_mel, y_mel, T_ys, self.criterion2) lozz2 = self.calc_loss(z_mel, x_mel, T_ys, self.criterion2) loss = loss1 + loss2*hp.l2_factor # for i,f in enumerate(self.filters): # s = self.f_specs[i][1] # losses[i] = self.calc_loss_smooth(x,y,T_ys,f, s ) # loss = loss + losses[i] for i,(k,s) in enumerate(self.f_specs): losses[i] = self.calc_loss_smooth2(x,y,T_ys,k, s ) losses_base[i] = self.calc_loss_smooth2(y,y,T_ys,k, s ) loss = loss + losses[i] avg_loss1.update(loss1.item(), len(T_ys)) avg_lozz1.update(lozz1.item(), len(T_ys)) avg_loss2.update(loss2.item(), len(T_ys)) avg_lozz2.update(lozz2.item(), len(T_ys)) avg_loss_tot.update(loss.item(), len(T_ys)) for j,l in enumerate(losses): avg_losses[j].update(l.item(), len(T_ys)) for j,l in enumerate(losses_base): avg_losses_base[j].update(l.item(), len(T_ys)) # print ##pbar.set_postfix_str(f'{avg_loss1.get_average():.1e}') # write summary if 0: for p in range(len(T_ys)): _x = x[p,0,:,:T_ys[p]].cpu() _y = y[p,0,:,:T_ys[p]].cpu() _z = z[p,0,:,:T_ys[p]].cpu() y_wav = data['wav'][p] ymin = _y[_y > 0].min() vmin, vmax = librosa.amplitude_to_db(np.array((ymin, _y.max()))) kwargs_fig = dict(vmin=vmin, vmax=vmax) if hp.request_drawings: fig_x = draw_spectrogram(_x, **kwargs_fig) self.writer.add_figure(f'Audio/1_DNN_Output', fig_x, cnt) fig_y = draw_spectrogram(_y, **kwargs_fig) fig_z = draw_spectrogram(_z, **kwargs_fig) self.writer.add_figure(f'Audio/0_Pseudo_Inverse', fig_z, cnt) self.writer.add_figure(f'Audio/2_Real_Spectrogram', fig_y, cnt) audio_x = self.audio_from_mag_spec(np.abs(_x.numpy())) x_scale = np.abs(audio_x).max() / 0.5 self.writer.add_audio(f'LWS/1_DNN_Output', torch.from_numpy(audio_x / x_scale), cnt, sample_rate=hp.sampling_rate) audio_y = self.audio_from_mag_spec(_y.numpy()) audio_z = self.audio_from_mag_spec(_z.numpy()) z_scale = np.abs(audio_z).max() / 0.5 y_scale = np.abs(audio_y).max() / 0.5 self.writer.add_audio(f'LWS/0_Pseudo_Inverse', torch.from_numpy(audio_z / z_scale), cnt, sample_rate=hp.sampling_rate) self.writer.add_audio(f'LWS/2_Real_Spectrogram', torch.from_numpy(audio_y / y_scale), cnt, sample_rate=hp.sampling_rate) ##import pdb; pdb.set_trace() stoi_scores = {'0_Pseudo_Inverse' : self.calc_stoi(y_wav, audio_z), '1_DNN_Output' : self.calc_stoi(y_wav, audio_x), '2_Real_Spectrogram' : self.calc_stoi(y_wav, audio_y)} self.writer.add_scalars(f'LWS/STOI', stoi_scores, cnt ) # self.writer.add_scalar(f'STOI/0_Pseudo_Inverse_LWS', self.calc_stoi(y_wav, audio_z) , cnt) # 
self.writer.add_scalar(f'STOI/1_DNN_Output_LWS', self.calc_stoi(y_wav, audio_x) , cnt) # self.writer.add_scalar(f'STOI/2_Real_Spectrogram_LWS', self.calc_stoi(y_wav, audio_y) , cnt) cnt = cnt + 1 for j, avg_loss in enumerate(avg_losses): k = self.f_specs[j][0] s = self.f_specs[j][1] self.writer.add_scalar(f'inspect/losses_breakdown', avg_loss.get_average(), j) for j, avg_loss in enumerate(avg_losses_base): k = self.f_specs[j][0] s = self.f_specs[j][1] self.writer.add_scalar(f'inspect/losses_base_breakdown', avg_loss.get_average(), j) for j, avg_loss in enumerate(avg_losses): avg_loss2 = avg_losses_base[j] k = self.f_specs[j][0] s = self.f_specs[j][1] self.writer.add_scalar(f'inspect/losses_normalized_breakdown', avg_loss2.get_average() / avg_loss.get_average() , j) # self.writer.add_scalar(f'valid/loss', avg_loss1.get_average(), epoch) # self.writer.add_scalar(f'valid/baseline', avg_lozz1.get_average(), epoch) # self.writer.add_scalar(f'valid/melinv_loss', avg_loss2.get_average(), epoch) # self.writer.add_scalar(f'valid/melinv_baseline', avg_lozz2.get_average(), epoch) # for j, avg_loss in enumerate(avg_losses): # k = self.f_specs[j][0] # s = self.f_specs[j][1] # self.writer.add_scalar(f'valid/losses_{k}_{s}', avg_loss.get_average(), epoch) # self.writer.add_scalar('valid/loss_total', avg_loss_tot.get_average(), epoch) self.model.train() return @torch.no_grad() def infer(self, loader: DataLoader, logdir: Path): """ Evaluate the performance of the model. :param loader: DataLoader to use. :param logdir: path of the result files. :param epoch: """ def save_feature(num_snr, i_speech: int, s_path_speech: str, speech: ndarray, mag_mel2spec) -> tuple: spec_clean = np.ascontiguousarray(librosa.stft(speech, **hp.kwargs_stft)) mag_clean = np.ascontiguousarray(np.abs(spec_clean)[..., np.newaxis]) signal_power = np.mean(np.abs(speech)**2) list_dict = [] list_snr_db = [] for _ in enumerate(range(num_snr)): snr_db = -6*np.random.rand() list_snr_db.append(snr_db) snr = librosa.db_to_power(snr_db) noise_power = signal_power / snr noisy = speech + np.sqrt(noise_power) * np.random.randn(len(speech)) spec_noisy = librosa.stft(noisy, **hp.kwargs_stft) spec_noisy = np.ascontiguousarray(spec_noisy) list_dict.append( dict(spec_noisy=spec_noisy, speech=speech, spec_clean=spec_clean, mag_clean=mag_mel2spec, path_speech=s_path_speech, length=len(speech), ) ) return list_snr_db, list_dict self.model.eval() os.makedirs(Path(logdir), exist_ok=True) ##import pdb; pdb.set_trace() cnt = 0 pbar = tqdm(loader, desc='mel2inference', postfix='[0]', dynamic_ncols=True) form= '{:05d}_mel2spec_{:+.2f}dB.npz' num_snr = hp.num_snr for i_iter, data in enumerate(pbar): ##import pdb; pdb.set_trace() y = self.preprocess(data) # B, C, F, T x_mel = self.model.spec_to_mel(y) T_ys = data['T_ys'] x = self.model(x_mel) # B, C, F, T for p in range(len(T_ys)): _x = x[p,0,:,:T_ys[p]].unsqueeze(2).cpu().numpy() ##import pdb; pdb.set_trace() speech = data['wav'][p].numpy() list_snr_db, list_dict = save_feature(num_snr, cnt, data['path_speech'][p] , speech, _x) cnt = cnt + 1 for snr_db, dict_result in zip(list_snr_db, list_dict): np.savez(logdir / form.format(cnt, snr_db), **dict_result, ) self.model.train() return
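# ---------------------------------------------------------------------------
# save_feature() in infer() above mixes white Gaussian noise into the clean
# speech at a random SNR between -6 dB and 0 dB.  The standalone sketch below
# shows the same arithmetic on a dummy signal: librosa.db_to_power converts
# the SNR from dB to a power ratio, and the noise is scaled so that
# signal_power / noise_power equals that ratio.  The helper name
# mix_at_random_snr is illustrative, not part of the project.
# ---------------------------------------------------------------------------
import numpy as np
import librosa


def mix_at_random_snr(speech: np.ndarray, max_attenuation_db: float = 6.0):
    """Return (noisy, snr_db) with white noise at a random SNR in (-max_attenuation_db, 0] dB."""
    snr_db = -max_attenuation_db * np.random.rand()
    snr = librosa.db_to_power(snr_db)              # dB -> power ratio
    signal_power = np.mean(np.abs(speech) ** 2)
    noise_power = signal_power / snr               # power the noise must have
    noisy = speech + np.sqrt(noise_power) * np.random.randn(len(speech))
    return noisy, snr_db


# example: a 1-second dummy "speech" signal at 16 kHz
#   noisy, snr_db = mix_at_random_snr(np.random.randn(16000).astype(np.float32))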
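# ---------------------------------------------------------------------------
# calc_loss_smooth2(), used throughout this Trainer, compares max-pooled
# versions of the predicted and target spectrograms: pooling with a (kern, 1)
# window keeps the local maxima along the frequency axis (dim -2 of the
# B, C, F, T tensors), and repeating the pooling on the negated tensors keeps
# the local minima, so both envelopes have to match.  The sketch below
# reproduces the core idea on dummy tensors; it omits the per-utterance time
# masking and loss_weight handling of the full method and is an illustration,
# not the trainer's exact code path.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def envelope_l1(x: torch.Tensor, y: torch.Tensor, kern: int, stride: int) -> torch.Tensor:
    """L1 distance between frequency-axis max/min envelopes of x and y (B, C, F, T)."""
    loss_max = F.l1_loss(F.max_pool2d(x, (kern, 1), stride=stride),
                         F.max_pool2d(y, (kern, 1), stride=stride))
    loss_min = F.l1_loss(F.max_pool2d(-x, (kern, 1), stride=stride),
                         F.max_pool2d(-y, (kern, 1), stride=stride))
    return loss_max + loss_min


# example with a dummy batch (B=2, C=1, F=257, T=100) and a (5, 2) spec from f_specs:
#   x, y = torch.rand(2, 1, 257, 100), torch.rand(2, 1, 257, 100)
#   envelope_l1(x, y, kern=5, stride=2)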