def test(self, loader: DataLoader, logdir: Path):
    """Evaluate the model on `loader` and write per-sample results to `logdir`.

    For every sample, the post-processed output is saved with
    ``ComplexSpecDataset.save_dirspec`` at ``logdir / hp.form_result``, and the
    measure returned by ``self.writer.write_one`` is accumulated in an
    ``AverageMeter``. The average is logged as text (index 0 under
    "Proposed", index 1 under "Reverberant") and printed.

    When ``hp.n_save_block_outs`` is truthy, forward hooks are attached to the
    children of ``self.module`` (members of ModuleList/ModuleDict are hooked
    individually). They save every block's output as figures and ``.mat``
    files. When it is positive, only the first ``hp.n_save_block_outs``
    batches are evaluated. The hooks are removed before returning, even on
    error.

    :param loader: DataLoader to evaluate.
    :param logdir: directory of the result files; the part of its name before
        the first '_' is used as the summary group.
    """
    group = logdir.name.split('_')[0]

    self.writer = CustomWriter(str(logdir), group=group)
    writer = self.writer

    # Per-batch counter of hooked blocks keyed by class name; stays None while
    # block outputs are not being saved.
    module_counts = None

    def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
        """Forward hook: save one block's output for the current batch."""
        # Class name taken from the module repr ('Name(...)' -> 'Name').
        module_name = str(module).split('(')[0]
        i_module = module_counts[module_name]
        dict_to_save = {'out': out.detach().cpu().numpy().squeeze()}
        for i, o in enumerate(dict_to_save['out']):
            writer.add_figure(
                f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                draw_spectrogram(o, to_db=False),
                i,
            )
        scio.savemat(
            str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
            dict_to_save,
        )
        module_counts[module_name] += 1

    hook_handles = []
    meter = None
    avg_measure = None
    self.model.eval()
    try:
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            for sub in self.module.children():
                # Hook the members of container modules, not the container.
                if isinstance(sub, nn.ModuleList):
                    targets = list(sub)
                elif isinstance(sub, nn.ModuleDict):
                    targets = list(sub.values())
                else:
                    targets = [sub]
                for m in targets:
                    hook_handles.append(m.register_forward_hook(save_forward))

        pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        cnt_sample = 0
        with torch.no_grad():  # inference only; no autograd graph needed
            for i_iter, data in enumerate(pbar):
                # Stop after hp.n_save_block_outs batches (when positive);
                # checked before preprocessing so the extra batch is skipped.
                if 0 < hp.n_save_block_outs == i_iter:
                    break
                if module_counts is not None:
                    # Restart block numbering for every batch.
                    module_counts = defaultdict(int)

                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                T_ys = data['T_ys']

                # forward (hooks registered above fire here, assuming
                # self.model runs self.module)
                _, output, residual = self.model(x, mag, max_length,
                                                 repeat=hp.repeat_test)

                # write summary
                for i_b in range(len(T_ys)):
                    i_sample = cnt_sample + i_b
                    one_sample = ComplexSpecDataset.decollate_padded(data, i_b)  # F, T, C

                    out_one = self.postprocess(output, residual, T_ys, i_b,
                                               loader.dataset)

                    ComplexSpecDataset.save_dirspec(
                        logdir / hp.form_result.format(i_sample),
                        **one_sample, **out_one)

                    measure = self.writer.write_one(
                        i_sample, **out_one, **one_sample,
                        suffix=f'_{hp.repeat_test}')
                    if meter is None:
                        meter = AverageMeter(init_value=measure)
                    else:
                        meter.update(measure)
                cnt_sample += len(T_ys)

        if meter is not None:
            avg_measure = meter.get_average()
            self.writer.add_text(f'{group}/Average Measure/Proposed',
                                 str(avg_measure[0]))
            self.writer.add_text(f'{group}/Average Measure/Reverberant',
                                 str(avg_measure[1]))
    finally:
        # Detach the hooks so they neither accumulate over repeated calls nor
        # fire during later training with a closed writer.
        for handle in hook_handles:
            handle.remove()
        self.model.train()
        self.writer.close()  # Explicitly close

    print()
    if avg_measure is None:
        # Empty loader: nothing was measured, so there is no average to print.
        print('Average: no samples were evaluated')
        return
    str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
    print(f'Average: {str_avg_measure}')
def train(self, loader_train: DataLoader, loader_valid: DataLoader,
          logdir: Path, first_epoch=0):
    """Train from `first_epoch` up to (excluding) `hp.n_epochs`.

    Each epoch runs the following steps:
    - log the current learning rate;
    - make one pass over `loader_train`, clipping the gradient norm at
      `hp.thr_clip_grad`;
    - log the mean training loss and mean gradient norm;
    - validate on `loader_valid` with `repeat=hp.repeat_train`;
    - every `hp.period_save_state` epochs, save
      `(module, optimizer, scheduler)` state dicts to `logdir / f'{epoch}.pt'`;
    - stop early when `self.should_stop` returns True.

    :param loader_train: DataLoader of the training set.
    :param loader_valid: DataLoader of the validation set.
    :param logdir: directory for summaries and checkpoints.
    :param first_epoch: epoch to start (or resume) from; also passed to the
        writer as `purge_step`.
    """
    self.writer = CustomWriter(str(logdir), group='train',
                               purge_step=first_epoch)

    try:
        # Start Training
        for epoch in range(first_epoch, hp.n_epochs):
            self.writer.add_scalar('loss/lr',
                                   self.optimizer.param_groups[0]['lr'],
                                   epoch)

            print()
            pbar = tqdm(loader_train, desc=f'epoch {epoch:3d}', postfix='[]',
                        dynamic_ncols=True)
            avg_loss = AverageMeter(float)
            avg_grad_norm = AverageMeter(float)

            for data in pbar:
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                T_ys = data['T_ys']

                # forward
                output_loss, _, _ = self.model(
                    x, mag, max_length, repeat=hp.repeat_train)  # B, C, F, T

                loss = self.calc_loss(output_loss, y, T_ys)

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), hp.thr_clip_grad)

                self.optimizer.step()

                # print
                avg_loss.update(loss.item(), len(T_ys))
                pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')
                avg_grad_norm.update(grad_norm)

            self.writer.add_scalar('loss/train', avg_loss.get_average(), epoch)
            self.writer.add_scalar('loss/grad', avg_grad_norm.get_average(),
                                   epoch)

            # Validation
            loss_valid = self.validate(loader_valid, logdir, epoch,
                                       repeat=hp.repeat_train)

            # save loss & model
            if epoch % hp.period_save_state == hp.period_save_state - 1:
                torch.save(
                    (
                        self.module.state_dict(),
                        self.optimizer.state_dict(),
                        self.scheduler.state_dict(),
                    ),
                    logdir / f'{epoch}.pt',
                )

            # Early stopping
            if self.should_stop(loss_valid, epoch):
                break
    finally:
        # Flush and close the summaries even if training raises or is interrupted.
        self.writer.close()
def validate(self, loader: DataLoader, logdir: Path, epoch: int, repeat=1):
    """Evaluate the model on `loader` and return the average validation loss.

    Only the first sample of the first batch is post-processed and written
    with ``self.writer.write_one`` at step `epoch`. The reference sample is
    cached in ``self.valid_eval_sample`` the first time it is needed. The
    averaged loss is logged under ``loss/valid{suffix}``. The model is put
    back into training mode before returning, even on error.

    :param loader: DataLoader to use.
    :param logdir: path of the result files (not used by the current body).
    :param epoch: step used for the summaries.
    :param repeat: forwarded to the model; when > 1, ``_{repeat}`` is appended
        to the summary tags.
    :return: ``AverageMeter.get_average()`` over the per-batch
        ``(loss, len(T_ys))`` updates.
    """
    suffix = f'_{repeat}' if repeat > 1 else ''

    self.model.eval()
    try:
        avg_loss = AverageMeter(float)

        pbar = tqdm(loader, desc='validate ', postfix='[0]', dynamic_ncols=True)
        with torch.no_grad():  # evaluation only; no autograd graph needed
            for i_iter, data in enumerate(pbar):
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                T_ys = data['T_ys']

                # forward
                output_loss, output, residual = self.model(
                    x, mag, max_length, repeat=repeat)

                # loss
                loss = self.calc_loss(output_loss, y, T_ys)
                avg_loss.update(loss.item(), len(T_ys))

                # print
                pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

                # write summary (first sample of the first batch only)
                if i_iter == 0:
                    # F, T, C
                    if not self.valid_eval_sample:
                        self.valid_eval_sample = \
                            ComplexSpecDataset.decollate_padded(data, 0)

                    out_one = self.postprocess(output, residual, T_ys, 0,
                                               loader.dataset)

                    # Pass the cached reference sample unless the writer's
                    # `reused_sample` flag is set (presumably: already written).
                    if not self.writer.reused_sample:
                        one_sample = self.valid_eval_sample
                    else:
                        one_sample = dict()

                    self.writer.write_one(epoch, **one_sample, **out_one,
                                          suffix=suffix)

        self.writer.add_scalar(f'loss/valid{suffix}', avg_loss.get_average(),
                               epoch)
    finally:
        # Return to training mode even if evaluation raises.
        self.model.train()

    return avg_loss.get_average()
feat0 = criterion_cons(fake_features[0], real_features[0]) * 3. feat1 = criterion_cons(fake_features[1], real_features[1]) * 2.5 feat2 = criterion_cons(fake_features[2], real_features[2]) * 2. feat3 = criterion_cons(fake_features[3], real_features[3]) * 1.5 feat4 = criterion_cons(fake_features[4], real_features[4]) * 1. loss_feat = (feat0 + feat1 + feat2 + feat3 + feat4) optimizer_G.zero_grad() loss_feat.backward() optimizer_G.step() running_loss_feat.update(loss_feat, image.size(0)) if global_step % steps_per_epoch == 0: epoch_loss_GD = running_loss_GD.get_average() epoch_loss_D = running_loss_D.get_average() epoch_loss_cons = running_loss_cons.get_average() epoch_loss_feat = running_loss_feat.get_average() running_loss_GD.reset() running_loss_D.reset() running_loss_cons.reset() running_loss_feat.reset() msg = "epoch- %d, loss_GD- %.4f, loss_cons- %.4f, loss_feat- %.4f, loss_D- %.4f" % ( global_epoch, epoch_loss_GD, epoch_loss_cons, epoch_loss_feat, epoch_loss_D) logger.info(msg)