def test(self, loader: DataLoader, logdir: Path):
    """Evaluate the model on `loader`, saving and logging one sample per batch.

    For every batch, sample 0 is post-processed, stored with
    `DirSpecDataset.save_dirspec` under `logdir` and passed to
    `self.writer.write_one`; the values returned by `write_one` are
    accumulated in an `AverageMeter`. After the loop the averaged measures
    are written as text summaries and printed.

    The model is put back into train mode and the writer is closed even if
    an iteration raises. If `loader` yields no batch, a notice is printed
    and no average is written.

    :param loader: DataLoader to iterate over.
    :param logdir: directory receiving the result files. The part of its name
        before the first '_' is used as the writer group, the progress-bar
        label and the prefix of the text summaries.
    """
    group = logdir.name.split('_')[0]
    self.writer = CustomWriter(str(logdir), group=group)

    meter = None
    self.model.eval()
    try:
        pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data (the target is not needed here)
            x, _ = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output = self.model(x)

            # write summary
            one_sample = DirSpecDataset.decollate_padded(data, 0)  # F, T, C
            out_one = self.postprocess(output, T_ys, 0, loader.dataset)

            DirSpecDataset.save_dirspec(
                logdir / hp.form_result.format(i_iter),
                **one_sample, **out_one
            )

            measure = self.writer.write_one(i_iter, **out_one, **one_sample)
            if meter is None:
                # NOTE(review): the first measure is registered with a count
                # of len(T_ys) while later ones use update()'s default count
                # -- confirm this weighting is intended.
                meter = AverageMeter(init_value=measure, init_count=len(T_ys))
            else:
                meter.update(measure)

        if meter is None:
            # Nothing was evaluated, so there is no average to report.
            print(f'{group}: the loader yielded no batch; nothing to average.')
            return

        avg_measure = meter.get_average()
        self.writer.add_text(f'{group}/Average Measure/Proposed',
                             str(avg_measure[0]))
        self.writer.add_text(f'{group}/Average Measure/Reverberant',
                             str(avg_measure[1]))
    finally:
        # Always undo eval mode and release the writer, even on failure.
        self.model.train()
        self.writer.close()

    print()
    str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
    print(f'Average: {str_avg_measure}')
def postprocess(self, output: Tensor, Ts: ndarray, idx: int,
                dataset: DirSpecDataset) -> Dict[str, ndarray]:
    """Extract one sample of a batched model output as a NumPy array.

    The sample is cut to its own length along the last axis, permuted with
    (1, 2, 0) when the model name starts with 'UNet', denormalized through
    `dataset.denormalize_` and moved to the CPU.

    The output is detached first, so the conversion to NumPy also works when
    `output` is tracked by autograd (the caller is not inside
    `torch.no_grad()`); without the detach `Tensor.numpy()` raises.

    :param output: batched model output, indexed as
        `output[idx, :, :, :Ts[idx]]`.
    :param Ts: per-sample lengths along the last axis of `output`.
    :param idx: index of the sample inside the batch.
    :param dataset: dataset providing `denormalize_(y=...)`.
    :return: dict with the single key 'out' holding the NumPy array.
    """
    # detach() shares storage with `output`, so nothing is copied here.
    one = output.detach()[idx, :, :, :Ts[idx]]
    if self.model_name.startswith('UNet'):
        one = one.permute(1, 2, 0)  # F, T, C

    one = dataset.denormalize_(y=one)
    one = one.cpu().numpy()

    return dict(out=one)
def validate(self, loader: DataLoader, logdir: Path, epoch: int):
    """Run one validation pass and return the average loss.

    The model is switched to eval mode for the pass and back to train mode
    afterwards. The first sample of the first batch is written through
    `self.writer.write_one`; the reference data of that sample is included
    only when `epoch` is 0.

    :param loader: DataLoader to use.
    :param logdir: path of the result files (not used by the current code).
    :param epoch: step index passed to the writer for the sample summary and
        for the 'loss/valid' scalar.
    :return: the value of `avg_loss.get_average()` after the pass.
    """
    self.model.eval()

    avg_loss = AverageMeter(float)

    progress = tqdm(loader, desc='validate ', postfix='[0]', dynamic_ncols=True)
    for step, batch in enumerate(progress):
        # inputs / targets, laid out as B, C, F, T
        x, y = self.preprocess(batch)
        lengths = batch['T_ys']

        # forward pass, cropped to the target length on the last axis
        prediction = self.model(x)[..., :y.shape[-1]]

        # accumulate the loss, weighted by the number of samples
        batch_loss = self.calc_loss(prediction, y, lengths)
        avg_loss.update(batch_loss.item(), len(lengths))

        # show the running average on the progress bar
        progress.set_postfix_str(f'{avg_loss.get_average():.1e}')

        # only the first batch contributes a sample summary
        if step != 0:
            continue

        # F, T, C
        reference = (
            DirSpecDataset.decollate_padded(batch, 0) if epoch == 0 else {}
        )
        estimate = self.postprocess(prediction, lengths, 0, loader.dataset)
        self.writer.write_one(epoch, **reference, **estimate)

    self.writer.add_scalar('loss/valid', avg_loss.get_average(), epoch)

    self.model.train()

    return avg_loss.get_average()
logdir_save /= foldername os.makedirs(logdir_save, exist_ok=True) # hp.batch_size /= 2 # epoch, state dict first_epoch = args.epoch + 1 if first_epoch > 0: path_state_dict = logdir_train / f'{args.epoch}.pt' if not path_state_dict.exists(): raise FileNotFoundError(path_state_dict) else: path_state_dict = None # Training + Validation Set dataset_temp = DirSpecDataset('train') dataset_train, dataset_valid = DirSpecDataset.split(dataset_temp, (hp.train_ratio, -1)) dataset_train.set_needs(**(hp.channels if not args.save else hp.channels_w_ph)) dataset_valid.set_needs(**hp.channels_w_ph) loader_train = DataLoader( dataset_train, batch_size=hp.batch_size, num_workers=hp.num_workers, collate_fn=dataset_train.pad_collate, pin_memory=(hp.device != 'cpu'), shuffle=(not args.save), ) loader_valid = DataLoader( dataset_valid,
def test(self, loader: DataLoader, logdir: Path):
    """Evaluate the model on `loader`, optionally dumping block outputs.

    For every batch, sample 0 is post-processed and passed to
    `self.writer.write_one`; the returned measures are accumulated in an
    `AverageMeter`, and their average is written as text summaries and
    printed at the end.

    When `hp.n_save_block_outs` is truthy, a forward hook is attached to
    every child of the model (and to every member of a child `ModuleList` /
    `ModuleDict`). The hook draws each module output as figures and saves it
    to a .mat file under `logdir`. If `hp.n_save_block_outs` is positive, the
    loop stops once that many batches have been processed.

    The hooks are removed, the model is put back into train mode and the
    writer is closed when the method exits, even if an iteration raises, so
    later forward passes do not keep writing block outputs. If no batch is
    evaluated, a notice is printed and no average is written.

    :param loader: DataLoader to iterate over.
    :param logdir: directory receiving the result files. The part of its name
        before the first '_' is used as the writer group and the
        progress-bar label.
    """
    group = logdir.name.split('_')[0]
    self.writer = CustomWriter(str(logdir), group=group)
    writer = self.writer

    n_block_outs = hp.n_save_block_outs
    # How often each module type has fired during the current forward pass;
    # used to tell apart several blocks of the same class.
    module_counts = defaultdict(int)

    def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
        """Forward hook: save the output of `module` as figures and a .mat."""
        module_name = str(module).split('(')[0]
        dict_to_save = dict(out=out.detach().cpu().numpy().squeeze())

        i_module = module_counts[module_name]
        # One figure per entry along the first axis of the squeezed output.
        for i, o in enumerate(dict_to_save['out']):
            writer.add_figure(
                f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                draw_spectrogram(o, to_db=False),
                i,
            )
        scio.savemat(
            str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
            dict_to_save,
        )
        module_counts[module_name] += 1

    hook_handles = []
    meter = None
    self.model.eval()
    try:
        # register hooks to save the output of each block
        if n_block_outs:
            if isinstance(self.model, nn.DataParallel):
                root = self.model.module
            else:
                root = self.model
            for sub in root.children():
                if isinstance(sub, nn.ModuleList):
                    targets = sub
                elif isinstance(sub, nn.ModuleDict):
                    targets = sub.values()
                else:
                    targets = (sub,)
                hook_handles.extend(
                    m.register_forward_hook(save_forward) for m in targets
                )

        pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data (the target is not needed here)
            x, _ = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            if n_block_outs:
                # restart the per-module numbering for this forward pass
                module_counts.clear()
                if 0 < n_block_outs == i_iter:
                    break

            # forward
            output = self.model(x)

            # write summary
            one_sample = DirSpecDataset.decollate_padded(data, 0)  # F, T, C
            out_one = self.postprocess(output, T_ys, 0, loader.dataset)

            measure = self.writer.write_one(
                i_iter,
                eval_with_y_ph=hp.eval_with_y_ph,
                **out_one, **one_sample,
            )
            if meter is None:
                # NOTE(review): the first measure is registered with a count
                # of len(T_ys) while later ones use update()'s default count
                # -- confirm this weighting is intended.
                meter = AverageMeter(init_value=measure, init_count=len(T_ys))
            else:
                meter.update(measure)

        if meter is None:
            # Nothing was evaluated, so there is no average to report.
            print(f'{group}: the loader yielded no batch; nothing to average.')
            return

        avg_measure = meter.get_average()
        self.writer.add_text('Average Measure/Proposed', str(avg_measure[0]))
        self.writer.add_text('Average Measure/Reverberant', str(avg_measure[1]))
    finally:
        # Detach the hooks so they do not fire after this method returns,
        # then undo eval mode and release the writer.
        for handle in hook_handles:
            handle.remove()
        self.model.train()
        self.writer.close()

    print()
    str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
    print(f'Average: {str_avg_measure}')
# Directory that receives the generated figures.
path_fig = path_root / 'figures'
path_fig.mkdir(exist_ok=True)

# %% hp
hp.init_dependent_vars()

# %% model & dataset
# Everything stays on the CPU: 'cuda:0' ran out of memory for this script.
device = 'cpu'
model = UNet(4, 1, 64, 4)
model = model.to(device)
# The checkpoint is indexable; its first element is the model state dict.
state_dict = torch.load(path_state_dict, map_location=device)[0]
model.load_state_dict(state_dict)

# Dataset: the 'train' dataset is built only to reuse its norm_modules.
dataset_temp = DirSpecDataset('train')
dataset_test = DirSpecDataset(
    kind,
    dataset_temp.norm_modules,
    **hp.channels_w_ph,
)

# %% retrieve data
# Collate a single sample into a batch of one.
data = dataset_test.pad_collate([dataset_test[idx_sample]])

x = data['normalized_x'].to(device)
y = data['normalized_y'].to(device)
y_denorm = data['y'].permute(0, 3, 1, 2)  # B, C, F, T

# Gradients with respect to the input are needed further on.
x.requires_grad = True

baseline = torch.zeros_like(data['x'])