else: path_state_dict = None # Training + Validation Set # run if args.infer and path_state_dict is None: trainer = None num_workers = 1 else: os.makedirs(logdir_train, exist_ok=True) trainer = Trainer(path_state_dict) num_workers = trainer.num_workers if args.train: dataset_train = ComplexSpecDataset('train') dataset_valid = ComplexSpecDataset('valid') dataset_train.set_needs(**hp.channels) ##dataset_valid.set_needs(**hp.channels) loader_train = DataLoader( dataset_train, batch_size=hp.batch_size, num_workers=num_workers, collate_fn=dataset_train.pad_collate, pin_memory=(hp.device != 'cpu'), shuffle=True, drop_last=True, ) loader_valid = DataLoader( dataset_valid,
def test(self, loader: DataLoader, logdir: Path):
    """Run the model over ``loader`` in eval mode and write per-sample results.

    For every batch the model output is post-processed, saved to disk
    (via ``hp.form_result``) and logged through ``self.writer``; the measures
    returned by ``writer.write_one`` are averaged and reported at the end.
    Optionally (``hp.n_save_block_outs``) intermediate block outputs are
    captured with forward hooks and dumped as figures and ``.mat`` files.

    :param loader: DataLoader yielding padded test batches.
    :param logdir: directory for result files; its name prefix (before the
        first ``_``) is used as the summary group name.
    """

    def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
        """Forward hook: dump a block's output as figures and a .mat file."""
        module_name = str(module).split('(')[0]
        dict_to_save = dict()
        # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
        dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

        i_module = module_counts[module_name]
        for i, o in enumerate(dict_to_save['out']):
            save_forward.writer.add_figure(
                f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                draw_spectrogram(o, to_db=False),
                i,
            )
        scio.savemat(
            str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
            dict_to_save,
        )
        module_counts[module_name] += 1

    group = logdir.name.split('_')[0]
    # Fix: reuse an already-open writer instead of clobbering it (consistent
    # with the other test variant in this file). getattr guards the case
    # where __init__ never assigned the attribute.
    if getattr(self, 'writer', None) is None:
        self.writer = CustomWriter(str(logdir), group=group)

    avg_measure = None
    self.model.eval()

    module_counts = None
    hook_handles = []  # keep RemovableHandles so the hooks can be detached
    if hp.n_save_block_outs:
        module_counts = defaultdict(int)
        save_forward.writer = self.writer
        for sub in self.module.children():
            if isinstance(sub, nn.ModuleList):
                for m in sub:
                    hook_handles.append(m.register_forward_hook(save_forward))
            elif isinstance(sub, nn.ModuleDict):
                for m in sub.values():
                    hook_handles.append(m.register_forward_hook(save_forward))
            else:
                hook_handles.append(sub.register_forward_hook(save_forward))

    pbar = tqdm(loader, desc=group, dynamic_ncols=True)

    cnt_sample = 0
    for i_iter, data in enumerate(pbar):
        # get data
        x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
        T_ys = data['T_ys']

        # forward; reset the per-module counters so block-output names
        # restart at 0 for every batch
        if module_counts is not None:
            module_counts = defaultdict(int)

        # stop once the requested number of block-output batches was saved
        if 0 < hp.n_save_block_outs == i_iter:
            break
        _, output, residual = self.model(x, mag, max_length,
                                         repeat=hp.repeat_test)

        # write summary
        for i_b in range(len(T_ys)):
            i_sample = cnt_sample + i_b
            one_sample = ComplexSpecDataset.decollate_padded(data, i_b)  # F, T, C

            out_one = self.postprocess(output, residual, T_ys, i_b,
                                       loader.dataset)

            ComplexSpecDataset.save_dirspec(
                logdir / hp.form_result.format(i_sample),
                **one_sample, **out_one)

            measure = self.writer.write_one(i_sample, **out_one, **one_sample,
                                            suffix=f'_{hp.repeat_test}')
            if avg_measure is None:
                avg_measure = AverageMeter(init_value=measure)
            else:
                avg_measure.update(measure)
            # print
            # str_measure = arr2str(measure).replace('\n', '; ')
            # pbar.write(str_measure)
        cnt_sample += len(T_ys)

    # Fix: remove the forward hooks registered above; otherwise they leak and
    # fire again on every subsequent forward pass of the model.
    for handle in hook_handles:
        handle.remove()

    self.model.train()

    # Fix: guard against an empty loader (or a block-output-only run) where
    # no measure was ever produced — the original crashed on None here.
    if avg_measure is None:
        self.writer.close()
        return

    avg_measure = avg_measure.get_average()
    self.writer.add_text(f'{group}/Average Measure/Proposed',
                         str(avg_measure[0]))
    self.writer.add_text(f'{group}/Average Measure/Reverberant',
                         str(avg_measure[1]))
    self.writer.close()  # Explicitly close

    print()
    str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
    print(f'Average: {str_avg_measure}')
def validate(self, loader: DataLoader, logdir: Path, epoch: int, repeat=1):
    """ Evaluate the performance of the model.

    :param loader: DataLoader to use.
    :param logdir: path of the result files.
    :param epoch: current epoch number; used as the summary step.
    :param repeat: forward-repeat count passed to the model; summary tags
        get a ``_{repeat}`` suffix when greater than 1.
    :return: average validation loss over the whole loader.
    """
    suffix = f'_{repeat}' if repeat > 1 else ''

    self.model.eval()

    avg_loss = AverageMeter(float)

    pbar = tqdm(loader, desc='validate ', postfix='[0]', dynamic_ncols=True)
    # Fix: validation never backpropagates, so disable autograd — avoids
    # building the graph for every forward pass (memory + speed).
    with torch.no_grad():
        for i_iter, data in enumerate(pbar):
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output_loss, output, residual = self.model(x, mag, max_length,
                                                       repeat=repeat)

            # loss
            loss = self.calc_loss(output_loss, y, T_ys)
            avg_loss.update(loss.item(), len(T_ys))

            # print: running average loss in the progress bar postfix
            pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

            # write summary for the first batch only
            if i_iter == 0:
                # F, T, C
                if not self.valid_eval_sample:
                    self.valid_eval_sample = ComplexSpecDataset.decollate_padded(data, 0)

                out_one = self.postprocess(output, residual, T_ys, 0,
                                           loader.dataset)

                # ComplexSpecDataset.save_dirspec(
                #     logdir / hp.form_result.format(epoch),
                #     **self.valid_eval_sample, **out_one
                # )

                # the writer caches the reference sample after first use,
                # so only send it once
                if not self.writer.reused_sample:
                    one_sample = self.valid_eval_sample
                else:
                    one_sample = dict()

                self.writer.write_one(epoch, **one_sample, **out_one,
                                      suffix=suffix)

    self.writer.add_scalar(f'loss/valid{suffix}', avg_loss.get_average(), epoch)

    self.model.train()

    return avg_loss.get_average()
os.makedirs(logdir_test) else: exit() os.makedirs(logdir_test, exist_ok=True) # epoch, state dict first_epoch = args.epoch + 1 if first_epoch > 0: path_state_dict = logdir_train / f'{args.epoch}.pt' if not path_state_dict.exists(): raise FileNotFoundError(path_state_dict) else: path_state_dict = None # Training + Validation Set dataset_temp = ComplexSpecDataset('train') dataset_train, dataset_valid = ComplexSpecDataset.split( dataset_temp, (hp.train_ratio, -1)) dataset_train.set_needs(**hp.channels) dataset_valid.set_needs(**hp.channels) # run trainer = Trainer(path_state_dict) if args.train: loader_train = DataLoader( dataset_train, batch_size=hp.batch_size, num_workers=hp.num_workers, collate_fn=dataset_train.pad_collate, pin_memory=(hp.device != 'cpu'), shuffle=True,
def test(self, loader: DataLoader, logdir: Path):
    """Benchmark the model ("deGLI") against plain Griffin-Lim ("GLA").

    For repeat counts 1, 2, 4, ... up to ``hp.repeat_test``, runs both the
    model and ``model.plain_gla`` on batch data, times each pass, evaluates
    every sample via ``writer.write_one``, and logs averaged STOI/PESQ
    scalars on linear and log2 axes.

    :param loader: DataLoader yielding padded test batches.
    :param logdir: result directory; its name prefix (before the first
        ``_``) is used as the summary group name.
    """

    def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
        # Forward hook: dump this block's output as figures and a .mat file.
        module_name = str(module).split('(')[0]
        dict_to_save = dict()
        # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
        dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

        i_module = module_counts[module_name]
        for i, o in enumerate(dict_to_save['out']):
            save_forward.writer.add_figure(
                f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                draw_spectrogram(o, to_db=False),
                i,
            )
        scio.savemat(
            str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
            dict_to_save,
        )
        module_counts[module_name] += 1

    group = logdir.name.split('_')[0]
    # reuse an already-open writer; only create one the first time
    if self.writer is None:
        self.writer = CustomWriter(str(logdir), group=group)

    avg_measure = None
    self.model.eval()
    # model depth: the scalar step below is repeats * depth
    depth = hp.model['depth']

    module_counts = None
    if hp.n_save_block_outs:
        # register save_forward on every child block (lists/dicts flattened)
        module_counts = defaultdict(int)
        save_forward.writer = self.writer
        for sub in self.module.children():
            if isinstance(sub, nn.ModuleList):
                for m in sub:
                    m.register_forward_hook(save_forward)
            elif isinstance(sub, nn.ModuleDict):
                for m in sub.values():
                    m.register_forward_hook(save_forward)
            else:
                sub.register_forward_hook(save_forward)

    ##pbar = tqdm(loader, desc=group, dynamic_ncols=True)

    cnt_sample = 0
    for i_iter, data in enumerate(loader):
        # cache of per-sample baseline evaluations, filled on first use
        sampleDict = {}
        # get data
        x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
        if hp.noisy_init:
            # NOTE(review): replaces the input spectrogram with pure noise
            # (ablation of the initial phase estimate) — confirm intent.
            x = torch.normal(0, 1, x.shape).cuda(self.in_device)
        T_ys = data['T_ys']

        # forward
        if module_counts is not None:
            module_counts = defaultdict(int)

        # if 0 < hp.n_save_block_outs == i_iter:
        #     break
        repeats = 1

        # warm-up passes so the timings below exclude first-call overhead
        # (NOTE(review): original formatting was flattened; assuming both
        # calls belong inside this warm-up loop — confirm.)
        for _ in range(3):
            _, output, residual = self.model(x, mag, max_length, repeat=1, train_step=1)  ##warn up!
            _, output = self.model.plain_gla(x, mag, max_length, repeat=repeats)

        # sweep repeat counts: 1, 2, 4, ... <= hp.repeat_test
        while repeats <= hp.repeat_test:
            # --- timed deGLI pass ---
            stime = ms()
            _, output, residual = self.model(x, mag, max_length, repeat=repeats, train_step=1)
            avg_measure = AverageMeter()
            avg_measure2 = AverageMeter()
            etime = ms(stime)
            # audio seconds processed per wall-clock second
            speed = (max_length / hp.fs) * len(T_ys) / (etime / 1000)
            ##print("degli: %d repeats, length: %d, time: %d miliseconds, ratio = %.02f" % (repeats, max_length , etime, speed))
            ##self.writer.add_scalar("Test Performance/degli", speed, repeats)

            # write summary for each sample of the batch (deGLI output)
            for i_b in tqdm(range(len(T_ys)), desc="degli, %d repeats" % repeats, dynamic_ncols=True):
                i_sample = cnt_sample + i_b
                # compute the baseline ("Base stats") evaluation once per
                # sample and cache it for reuse by the GLA pass below
                if not i_b in sampleDict:
                    one_sample = ComplexSpecDataset.decollate_padded(data, i_b)
                    reused_sample, result_eval_glim = self.writer.write_zero(0, i_b, **one_sample, suffix="Base stats")
                    sampleDict[i_b] = (reused_sample, result_eval_glim)

                sampleItem = sampleDict[i_b]
                reused_sample = sampleItem[0]
                result_eval_glim = sampleItem[1]

                out_one = self.postprocess(output, residual, T_ys, i_b, loader.dataset)

                # ComplexSpecDataset.save_dirspec(
                #     logdir / hp.form_result.format(i_sample),
                #     **one_sample, **out_one
                # )

                measure = self.writer.write_one(repeats, i_b, result_eval_glim, reused_sample, **out_one, suffix="3_deGLI")
                avg_measure.update(measure)

            # --- timed plain-GLA pass at the same repeat count ---
            stime = ms()
            _, output = self.model.plain_gla(x, mag, max_length, repeat=repeats)
            etime = ms(stime)
            speed = (1000 * max_length / hp.fs) * len(T_ys) / (etime)
            ##print("pure gla: %d repeats, length: %d, time: %d miliseconds, ratio = %.02f" % (repeats, max_length , etime, speed))
            ##self.writer.add_scalar("Test Performance/gla", speed, repeats)

            # write summary for each sample of the batch (GLA output;
            # residual is None — plain_gla produces none)
            for i_b in tqdm(range(len(T_ys)), desc="GLA, %d repeats" % repeats, dynamic_ncols=True):
                i_sample = cnt_sample + i_b
                sampleItem = sampleDict[i_b]
                reused_sample = sampleItem[0]
                result_eval_glim = sampleItem[1]
                out_one = self.postprocess(output, None, T_ys, i_b, loader.dataset)
                measure = self.writer.write_one(repeats, i_b, result_eval_glim, reused_sample, **out_one, suffix="4_GLA")
                avg_measure2.update(measure)

            cnt_sample += len(T_ys)

            # averaged measures: assuming [0, 0] is STOI and [0, 1] is PESQ
            # per the tag names — TODO confirm against writer.write_one.
            self.writer.add_scalar(f'STOI/Average Measure/deGLI', avg_measure.get_average()[0, 0], repeats * depth)
            self.writer.add_scalar(f'STOI/Average Measure/GLA', avg_measure2.get_average()[0, 0], repeats * depth)
            self.writer.add_scalar(f'STOI/Average Measure/deGLI_semilogx', avg_measure.get_average()[0, 0], int(repeats * depth).bit_length())
            self.writer.add_scalar(f'STOI/Average Measure/GLA_semilogx', avg_measure2.get_average()[0, 0], int(repeats * depth).bit_length())
            self.writer.add_scalar(f'PESQ/Average Measure/deGLI', avg_measure.get_average()[0, 1], repeats * depth)
            self.writer.add_scalar(f'PESQ/Average Measure/GLA', avg_measure2.get_average()[0, 1], repeats * depth)
            self.writer.add_scalar(f'PESQ/Average Measure/deGLI_semilogx', avg_measure.get_average()[0, 1], int(repeats * depth).bit_length())
            self.writer.add_scalar(f'PESQ/Average Measure/GLA_semilogx', avg_measure2.get_average()[0, 1], int(repeats * depth).bit_length())
            repeats = repeats * 2

        # NOTE(review): this break stops after the FIRST batch only — the
        # whole sweep runs on one batch. Looks like a deliberate shortcut
        # (or debugging leftover); confirm before removing.
        break

    self.model.train()
    self.writer.close()  # Explicitly close