    def test(self, loader: DataLoader, logdir: Path):
        """ Evaluate the model on the test set and save per-sample results.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        """
        group = logdir.name.split('_')[0]  # e.g. 'test' from a name like 'test_...'

        self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()

        pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data
            x, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output = self.model(x)  # [..., :y.shape[-1]]

            # write summary
            one_sample = DirSpecDataset.decollate_padded(data, 0)  # F, T, C

            out_one = self.postprocess(output, T_ys, 0, loader.dataset)

            DirSpecDataset.save_dirspec(
                logdir / hp.form_result.format(i_iter),
                **one_sample, **out_one
            )

            measure = self.writer.write_one(i_iter, **out_one, **one_sample)
            if avg_measure is None:
                avg_measure = AverageMeter(init_value=measure, init_count=len(T_ys))
            else:
                # weight by batch size, matching init_count above
                avg_measure.update(measure, len(T_ys))
            # print
            # str_measure = arr2str(measure).replace('\n', '; ')
            # pbar.write(str_measure)

        self.model.train()

        avg_measure = avg_measure.get_average()

        self.writer.add_text(f'{group}/Average Measure/Proposed', str(avg_measure[0]))
        self.writer.add_text(f'{group}/Average Measure/Reverberant', str(avg_measure[1]))
        self.writer.close()  # Explicitly close

        print()
        str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
        print(f'Average: {str_avg_measure}')

    def postprocess(self, output: Tensor, Ts: ndarray, idx: int,
                    dataset: DirSpecDataset) -> Dict[str, ndarray]:
        """ Pick one item from a padded batch output and denormalize it. """
        one = output[idx, :, :, :Ts[idx]]  # C, F, T (padding truncated)
        if self.model_name.startswith('UNet'):
            one = one.permute(1, 2, 0)  # F, T, C

        one = dataset.denormalize_(y=one)  # denormalize in place (trailing-underscore convention)
        one = one.cpu().numpy()

        return dict(out=one)

    def validate(self, loader: DataLoader, logdir: Path, epoch: int):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch: current epoch number, used as the summary step.
        """

        self.model.eval()

        avg_loss = AverageMeter(float)

        pbar = tqdm(loader,
                    desc='validate ',
                    postfix='[0]',
                    dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data
            x, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output = self.model(x)[..., :y.shape[-1]]

            # loss
            loss = self.calc_loss(output, y, T_ys)
            avg_loss.update(loss.item(), len(T_ys))

            # print
            pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

            # write summary
            if i_iter == 0:
                # the ground-truth sample is the same every epoch, so log it only once
                if epoch == 0:
                    one_sample = DirSpecDataset.decollate_padded(data, 0)  # F, T, C
                else:
                    one_sample = dict()

                out_one = self.postprocess(output, T_ys, 0, loader.dataset)

                # DirSpecDataset.save_dirspec(
                #     logdir / hp.form_result.format(epoch),
                #     **one_sample, **out_one
                # )

                self.writer.write_one(epoch, **one_sample, **out_one)

        self.writer.add_scalar('loss/valid', avg_loss.get_average(), epoch)

        self.model.train()

        return avg_loss.get_average()
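
# A minimal sketch of the `AverageMeter` used above, inferred from its call
# sites (`AverageMeter(float)`, `AverageMeter(init_value=..., init_count=...)`,
# `update(value, count)`, `get_average()`). The repo's real implementation may
# differ; the name `AverageMeterSketch` is hypothetical.
class AverageMeterSketch:
    """Running weighted average of floats or numpy arrays."""

    def __init__(self, dtype=float, init_value=None, init_count=0):
        self.sum = init_value if init_value is not None else dtype()
        self.count = init_count if init_value is not None else 0

    def update(self, value, count: int = 1):
        self.sum = self.sum + value  # works for scalars and ndarrays alike
        self.count += count

    def get_average(self):
        return self.sum / max(self.count, 1)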
    logdir_save /= foldername

    os.makedirs(logdir_save, exist_ok=True)
    # hp.batch_size /= 2

# resume from a checkpoint if args.epoch >= 0; otherwise train from scratch
first_epoch = args.epoch + 1
if first_epoch > 0:
    path_state_dict = logdir_train / f'{args.epoch}.pt'
    if not path_state_dict.exists():
        raise FileNotFoundError(path_state_dict)
else:
    path_state_dict = None
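
# The f'{epoch}.pt' checkpoints apparently store a tuple whose first element is
# the model state dict (cf. `torch.load(path_state_dict, ...)[0]` further down
# in this section). A minimal round-trip sketch under that assumption:
#
#     torch.save((model.state_dict(),), logdir_train / f'{epoch}.pt')
#     state_dict = torch.load(logdir_train / f'{epoch}.pt', map_location='cpu')[0]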

# Training + Validation Set
dataset_temp = DirSpecDataset('train')
dataset_train, dataset_valid = DirSpecDataset.split(dataset_temp,
                                                    (hp.train_ratio, -1))
dataset_train.set_needs(**(hp.channels if not args.save else hp.channels_w_ph))
dataset_valid.set_needs(**hp.channels_w_ph)

loader_train = DataLoader(
    dataset_train,
    batch_size=hp.batch_size,
    num_workers=hp.num_workers,
    collate_fn=dataset_train.pad_collate,
    pin_memory=(hp.device != 'cpu'),
    shuffle=(not args.save),
)
loader_valid = DataLoader(
    dataset_valid,
    batch_size=hp.batch_size,
    num_workers=hp.num_workers,
    collate_fn=dataset_valid.pad_collate,
    pin_memory=(hp.device != 'cpu'),
    shuffle=False,  # assumed: mirrors loader_train, without shuffling
)
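
# A minimal sketch of what a `pad_collate` like the ones above typically does:
# zero-pad every item along the time axis to the longest item in the batch and
# record the original target lengths as 'T_ys'. The field names ('x', 'y',
# 'T_ys') follow this file's usage; the real collate function handles more
# fields, and `pad_collate_sketch` is a hypothetical name.
def pad_collate_sketch(batch: list) -> dict:
    import torch
    out = {'T_ys': torch.tensor([item['y'].shape[1] for item in batch])}
    for key in ('x', 'y'):  # items are assumed to be dicts of F x T x C tensors
        T_max = max(item[key].shape[1] for item in batch)
        out[key] = torch.stack([
            torch.nn.functional.pad(torch.as_tensor(item[key]),
                                    (0, 0, 0, T_max - item[key].shape[1]))
            for item in batch
        ])  # B, F, T_max, C
    return out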

    def test(self, loader: DataLoader, logdir: Path):
        def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
            """ Forward hook: save one block's output as figures and a .mat file.

            `group`, `i_iter`, and `module_counts` are read from the enclosing scope.
            """
            module_name = str(module).split('(')[0]  # class name, e.g. 'Conv2d'
            dict_to_save = dict()
            # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
            dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

            i_module = module_counts[module_name]
            for i, o in enumerate(dict_to_save['out']):
                save_forward.writer.add_figure(
                    f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                    draw_spectrogram(o, to_db=False),
                    i,
                )
            scio.savemat(
                str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
                dict_to_save,
            )
            module_counts[module_name] += 1

        group = logdir.name.split('_')[0]

        self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()

        # register hook to save output of each block
        module_counts = None
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            save_forward.writer = self.writer
            if isinstance(self.model, nn.DataParallel):
                module = self.model.module
            else:
                module = self.model
            for sub in module.children():
                if isinstance(sub, nn.ModuleList):
                    for m in sub:
                        m.register_forward_hook(save_forward)
                elif isinstance(sub, nn.ModuleDict):
                    for m in sub.values():
                        m.register_forward_hook(save_forward)
                else:
                    sub.register_forward_hook(save_forward)

        pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data
            x, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            if module_counts is not None:
                module_counts = defaultdict(int)  # reset per-module counters for this batch

            if 0 < hp.n_save_block_outs == i_iter:
                # block outputs are only needed for the first n_save_block_outs batches
                break
            output = self.model(x)  # [..., :y.shape[-1]]

            # write summary
            one_sample = DirSpecDataset.decollate_padded(data, 0)  # F, T, C
            out_one = self.postprocess(output, T_ys, 0, loader.dataset)

            # DirSpecDataset.save_dirspec(
            #     logdir / hp.form_result.format(i_iter),
            #     **one_sample, **out_one
            # )

            measure = self.writer.write_one(
                i_iter,
                eval_with_y_ph=hp.eval_with_y_ph,
                **out_one,
                **one_sample,
            )
            if avg_measure is None:
                avg_measure = AverageMeter(init_value=measure,
                                           init_count=len(T_ys))
            else:
                # weight by batch size, matching init_count above
                avg_measure.update(measure, len(T_ys))

        self.model.train()

        avg_measure = avg_measure.get_average()

        self.writer.add_text('Average Measure/Proposed', str(avg_measure[0]))
        self.writer.add_text('Average Measure/Reverberant',
                             str(avg_measure[1]))
        self.writer.close()  # Explicitly close

        print()
        str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
        print(f'Average: {str_avg_measure}')
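
# A minimal sketch of the `draw_spectrogram` helper passed to `add_figure` in
# `save_forward` above: it renders a 2-D magnitude array as a matplotlib
# Figure, optionally converting to dB first. The real helper likely adds a
# colorbar and nicer axes; `draw_spectrogram_sketch` is a hypothetical name.
def draw_spectrogram_sketch(data, to_db: bool = True):
    import matplotlib.pyplot as plt
    import numpy as np
    if to_db:
        data = 20 * np.log10(np.abs(data) + 1e-12)  # epsilon avoids log(0)
    fig, ax = plt.subplots()
    ax.imshow(data, origin='lower', aspect='auto')
    ax.set_xlabel('frame')
    ax.set_ylabel('frequency bin')
    return fig  # `SummaryWriter.add_figure` accepts a matplotlib Figure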
path_fig = path_root / 'figures'
path_fig.mkdir(exist_ok=True)

# %% hp
hp.init_dependent_vars()

# %% model & dataset
device = 'cpu'  # running on CUDA ('cuda:0') runs out of memory, so use the CPU
model = UNet(4, 1, 64, 4).to(device)  # same hyperparameters as at training time
state_dict = torch.load(path_state_dict, map_location=device)[0]  # first element of the checkpoint is the state dict
model.load_state_dict(state_dict)

# Dataset (the test set reuses the training set's normalization statistics)
dataset_temp = DirSpecDataset('train')
dataset_test = DirSpecDataset(kind,
                              dataset_temp.norm_modules,
                              **hp.channels_w_ph)

# %% retrieve data
data = dataset_test.pad_collate([dataset_test[idx_sample]])  # a batch holding a single sample

x, y = data['normalized_x'], data['normalized_y']
x, y = x.to(device), y.to(device)
y_denorm = data['y']
y_denorm = y_denorm.permute(0, 3, 1, 2)  # B, C, F, T

x.requires_grad = True  # track gradients w.r.t. the input

baseline = torch.zeros_like(data['x'])  # all-zeros reference input
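
# The `requires_grad` flag and the all-zeros `baseline` above point to a
# gradient-based attribution analysis. Below is a minimal Integrated-Gradients
# style sketch under that assumption (the output is reduced with `.sum()` to
# get a scalar; `n_steps`, `x_interp`, and `attribution` are hypothetical
# names, and the repo's actual analysis may differ):
n_steps = 32
grads = torch.zeros_like(x)
for alpha in torch.linspace(0.0, 1.0, n_steps):
    # interpolate between the baseline and the input as a fresh leaf tensor
    x_interp = (baseline + alpha * (x - baseline)).detach().requires_grad_(True)
    model(x_interp).sum().backward()
    grads += x_interp.grad / n_steps
attribution = ((x - baseline) * grads).detach()  # same shape as the input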