Example #1
class Trainer:
    def __init__(self, path_state_dict=''):
        self.model = DeGLI(**hp.model)
        self.module = self.model
        self.criterion = nn.L1Loss(reduction='none')
        self.optimizer = Adam(
            self.model.parameters(),
            lr=hp.learning_rate,
            weight_decay=hp.weight_decay,
        )

        self.__init_device(hp.device, hp.out_device)

        self.scheduler = lr_scheduler.ReduceLROnPlateau(
            self.optimizer, **hp.scheduler)
        self.max_epochs = hp.n_epochs

        self.writer: Optional[CustomWriter] = None

        self.valid_eval_sample: Dict[str, Any] = dict()

        # if hp.model['final_avg']:
        #     len_weight = hp.repeat_train
        # else:
        #     len_weight = hp.model['depth'] * hp.repeat_train
        len_weight = hp.repeat_train
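        # One weight per DeGLI repeat: 1/len_weight, ..., 1/2, 1 (normalized below),
        # so later repeats contribute more to the training loss (see calc_loss).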
        self.loss_weight = torch.tensor(
            [1. / i for i in range(len_weight, 0, -1)],
            device=self.out_device,
        )
        self.loss_weight /= self.loss_weight.sum()

        # Load State Dict
        if path_state_dict:
            st_model, st_optim, st_sched = torch.load(
                path_state_dict, map_location=self.in_device)
            try:
                self.module.load_state_dict(st_model)
                self.optimizer.load_state_dict(st_optim)
                self.scheduler.load_state_dict(st_sched)
            except Exception as e:
                raise RuntimeError('The model is different from the state dict.') from e

        path_summary = hp.logdir / 'summary.txt'
        if not path_summary.exists():
            # print_to_file(
            #     path_summary,
            #     summary,
            #     (self.model, hp.dummy_input_size),
            #     dict(device=self.str_device[:4])
            # )
            with path_summary.open('w') as f:
                f.write('\n')
            with (hp.logdir / 'hparams.txt').open('w') as f:
                f.write(repr(hp))

    def __init_device(self, device, out_device):
        """

        :type device: Union[int, str, Sequence]
        :type out_device: Union[int, str, Sequence]
        :return:
        """
        if device == 'cpu':
            self.in_device = torch.device('cpu')
            self.out_device = torch.device('cpu')
            self.str_device = 'cpu'
            return

        # device type: List[int]
        if type(device) == int:
            device = [device]
        elif type(device) == str:
            device = [int(device.replace('cuda:', ''))]
        else:  # sequence of devices
            if type(device[0]) != int:
                device = [int(d.replace('cuda:', '')) for d in device]

        self.in_device = torch.device(f'cuda:{device[0]}')

        if len(device) > 1:
            if type(out_device) == int:
                self.out_device = torch.device(f'cuda:{out_device}')
            else:
                self.out_device = torch.device(out_device)
            self.str_device = ', '.join([f'cuda:{d}' for d in device])

            self.model = nn.DataParallel(self.model,
                                         device_ids=device,
                                         output_device=self.out_device)
        else:
            self.out_device = self.in_device
            self.str_device = str(self.in_device)

        self.model.cuda(self.in_device)
        self.criterion.cuda(self.out_device)

        torch.cuda.set_device(self.in_device)

    def preprocess(self, data: Dict[str, Tensor]) -> Tuple[Tensor, Tensor, int, Tensor]:
        # B, F, T, C
        x = data['x']
        mag = data['y_mag']
        max_length = max(data['length'])
        y = data['y']

        x = x.to(self.in_device, non_blocking=True)
        mag = mag.to(self.in_device, non_blocking=True)
        y = y.to(self.out_device, non_blocking=True)

        return x, mag, max_length, y

    @torch.no_grad()
    def postprocess(self, output: Tensor, residual: Tensor, Ts: ndarray,
                    idx: int,
                    dataset: ComplexSpecDataset) -> Dict[str, ndarray]:
        dict_one = dict(out=output, res=residual)
        for key in dict_one:
            one = dict_one[key][idx, :, :, :Ts[idx]]
            one = one.permute(1, 2, 0).contiguous()  # F, T, 2

            one = one.cpu().numpy().view(dtype=np.complex64)  # F, T, 1
            dict_one[key] = one

        return dict_one

    def calc_loss(self, out_blocks: Tensor, y: Tensor,
                  T_ys: Sequence[int]) -> Tensor:
        """
        out_blocks: B, depth, C, F, T
        y: B, C, F, T
        """

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            loss_no_red = self.criterion(out_blocks, y.unsqueeze(1))
        loss_blocks = torch.zeros(out_blocks.shape[1], device=y.device)
        for T, loss_batch in zip(T_ys, loss_no_red):
            loss_blocks += torch.mean(loss_batch[..., :T], dim=(1, 2, 3))

        if len(loss_blocks) == 1:
            loss = loss_blocks.squeeze()
        else:
            loss = loss_blocks @ self.loss_weight
        return loss
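    # Illustrative example (not from the original source): with hp.repeat_train == 3,
    # loss_weight is proportional to [1/3, 1/2, 1], so `loss_blocks @ self.loss_weight`
    # weights the last repeat's reconstruction error most heavily.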

    @torch.no_grad()
    def should_stop(self, loss_valid, epoch):
        if epoch == self.max_epochs - 1:
            return True
        self.scheduler.step(loss_valid)
        # if self.scheduler.t_epoch == 0:  # if it is restarted now
        #     # if self.loss_last_restart < loss_valid:
        #     #     return True
        #     if self.loss_last_restart * hp.threshold_stop < loss_valid:
        #         self.max_epochs = epoch + self.scheduler.restart_period + 1
        #     self.loss_last_restart = loss_valid

    def train(self,
              loader_train: DataLoader,
              loader_valid: DataLoader,
              logdir: Path,
              first_epoch=0):
        self.writer = CustomWriter(str(logdir),
                                   group='train',
                                   purge_step=first_epoch)

        # Start Training
        for epoch in range(first_epoch, hp.n_epochs):
            self.writer.add_scalar('loss/lr',
                                   self.optimizer.param_groups[0]['lr'], epoch)
            print()
            pbar = tqdm(loader_train,
                        desc=f'epoch {epoch:3d}',
                        postfix='[]',
                        dynamic_ncols=True)
            avg_loss = AverageMeter(float)
            avg_grad_norm = AverageMeter(float)

            for i_iter, data in enumerate(pbar):
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                T_ys = data['T_ys']

                # forward
                output_loss, _, _ = self.model(
                    x, mag, max_length, repeat=hp.repeat_train)  # B, C, F, T

                loss = self.calc_loss(output_loss, y, T_ys)

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), hp.thr_clip_grad)

                self.optimizer.step()

                # print
                avg_loss.update(loss.item(), len(T_ys))
                pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')
                avg_grad_norm.update(grad_norm)

            self.writer.add_scalar('loss/train', avg_loss.get_average(), epoch)
            self.writer.add_scalar('loss/grad', avg_grad_norm.get_average(),
                                   epoch)

            # Validation
            # loss_valid = self.validate(loader_valid, logdir, epoch)
            loss_valid = self.validate(loader_valid,
                                       logdir,
                                       epoch,
                                       repeat=hp.repeat_train)

            # save loss & model
            if epoch % hp.period_save_state == hp.period_save_state - 1:
                torch.save((
                    self.module.state_dict(),
                    self.optimizer.state_dict(),
                    self.scheduler.state_dict(),
                ), logdir / f'{epoch}.pt')

            # Early stopping
            if self.should_stop(loss_valid, epoch):
                break

        self.writer.close()

    @torch.no_grad()
    def validate(self, loader: DataLoader, logdir: Path, epoch: int, repeat=1):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch:
        """
        suffix = f'_{repeat}' if repeat > 1 else ''

        self.model.eval()

        avg_loss = AverageMeter(float)

        pbar = tqdm(loader,
                    desc='validate ',
                    postfix='[0]',
                    dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output_loss, output, residual = self.model(x,
                                                       mag,
                                                       max_length,
                                                       repeat=repeat)

            # loss
            loss = self.calc_loss(output_loss, y, T_ys)
            avg_loss.update(loss.item(), len(T_ys))

            # print
            pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

            # write summary
            if i_iter == 0:
                # F, T, C
                if not self.valid_eval_sample:
                    self.valid_eval_sample = ComplexSpecDataset.decollate_padded(
                        data, 0)

                out_one = self.postprocess(output, residual, T_ys, 0,
                                           loader.dataset)

                # ComplexSpecDataset.save_dirspec(
                #     logdir / hp.form_result.format(epoch),
                #     **self.valid_eval_sample, **out_one
                # )

                if not self.writer.reused_sample:
                    one_sample = self.valid_eval_sample
                else:
                    one_sample = dict()

                self.writer.write_one(epoch,
                                      **one_sample,
                                      **out_one,
                                      suffix=suffix)

        self.writer.add_scalar(f'loss/valid{suffix}', avg_loss.get_average(),
                               epoch)

        self.model.train()

        return avg_loss.get_average()

    @torch.no_grad()
    def test(self, loader: DataLoader, logdir: Path):
        def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
            module_name = str(module).split('(')[0]
            dict_to_save = dict()
            # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
            dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

            i_module = module_counts[module_name]
            for i, o in enumerate(dict_to_save['out']):
                save_forward.writer.add_figure(
                    f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                    draw_spectrogram(o, to_db=False),
                    i,
                )
            scio.savemat(
                str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
                dict_to_save,
            )
            module_counts[module_name] += 1

        group = logdir.name.split('_')[0]

        self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()

        module_counts = None
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            save_forward.writer = self.writer
            for sub in self.module.children():
                if isinstance(sub, nn.ModuleList):
                    for m in sub:
                        m.register_forward_hook(save_forward)
                elif isinstance(sub, nn.ModuleDict):
                    for m in sub.values():
                        m.register_forward_hook(save_forward)
                else:
                    sub.register_forward_hook(save_forward)

        pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        cnt_sample = 0
        for i_iter, data in enumerate(pbar):
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            if module_counts is not None:
                module_counts = defaultdict(int)

            # when saving block outputs, stop after the first `n_save_block_outs` batches
            if 0 < hp.n_save_block_outs == i_iter:
                break
            _, output, residual = self.model(x,
                                             mag,
                                             max_length,
                                             repeat=hp.repeat_test)

            # write summary
            for i_b in range(len(T_ys)):
                i_sample = cnt_sample + i_b
                one_sample = ComplexSpecDataset.decollate_padded(
                    data, i_b)  # F, T, C

                out_one = self.postprocess(output, residual, T_ys, i_b,
                                           loader.dataset)

                ComplexSpecDataset.save_dirspec(
                    logdir / hp.form_result.format(i_sample), **one_sample,
                    **out_one)

                measure = self.writer.write_one(i_sample,
                                                **out_one,
                                                **one_sample,
                                                suffix=f'_{hp.repeat_test}')
                if avg_measure is None:
                    avg_measure = AverageMeter(init_value=measure)
                else:
                    avg_measure.update(measure)
                # print
                # str_measure = arr2str(measure).replace('\n', '; ')
                # pbar.write(str_measure)
            cnt_sample += len(T_ys)

        self.model.train()

        avg_measure = avg_measure.get_average()

        self.writer.add_text(f'{group}/Average Measure/Proposed',
                             str(avg_measure[0]))
        self.writer.add_text(f'{group}/Average Measure/Reverberant',
                             str(avg_measure[1]))
        self.writer.close()  # Explicitly close

        print()
        str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
        print(f'Average: {str_avg_measure}')
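# --- Usage sketch (illustrative, not from the original source; the loaders and the
# --- hyper-parameter module `hp` are assumed to be built elsewhere) ---
#
#   trainer = Trainer(path_state_dict='')
#   trainer.train(loader_train, loader_valid, hp.logdir, first_epoch=0)
#   trainer.test(loader_test, hp.logdir / 'test')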
Example #2
    def test(self, loader: DataLoader, logdir: Path):
        def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
            module_name = str(module).split('(')[0]
            dict_to_save = dict()
            # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
            dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

            i_module = module_counts[module_name]
            for i, o in enumerate(dict_to_save['out']):
                save_forward.writer.add_figure(
                    f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                    draw_spectrogram(o, to_db=False),
                    i,
                )
            scio.savemat(
                str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
                dict_to_save,
            )
            module_counts[module_name] += 1

        group = logdir.name.split('_')[0]

        self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()

        module_counts = None
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            save_forward.writer = self.writer
            for sub in self.module.children():
                if isinstance(sub, nn.ModuleList):
                    for m in sub:
                        m.register_forward_hook(save_forward)
                elif isinstance(sub, nn.ModuleDict):
                    for m in sub.values():
                        m.register_forward_hook(save_forward)
                else:
                    sub.register_forward_hook(save_forward)

        pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        cnt_sample = 0
        for i_iter, data in enumerate(pbar):
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            if module_counts is not None:
                module_counts = defaultdict(int)

            # when saving block outputs, stop after the first `n_save_block_outs` batches
            if 0 < hp.n_save_block_outs == i_iter:
                break
            _, output, residual = self.model(x,
                                             mag,
                                             max_length,
                                             repeat=hp.repeat_test)

            # write summary
            for i_b in range(len(T_ys)):
                i_sample = cnt_sample + i_b
                one_sample = ComplexSpecDataset.decollate_padded(
                    data, i_b)  # F, T, C

                out_one = self.postprocess(output, residual, T_ys, i_b,
                                           loader.dataset)

                ComplexSpecDataset.save_dirspec(
                    logdir / hp.form_result.format(i_sample), **one_sample,
                    **out_one)

                measure = self.writer.write_one(i_sample,
                                                **out_one,
                                                **one_sample,
                                                suffix=f'_{hp.repeat_test}')
                if avg_measure is None:
                    avg_measure = AverageMeter(init_value=measure)
                else:
                    avg_measure.update(measure)
                # print
                # str_measure = arr2str(measure).replace('\n', '; ')
                # pbar.write(str_measure)
            cnt_sample += len(T_ys)

        self.model.train()

        avg_measure = avg_measure.get_average()

        self.writer.add_text(f'{group}/Average Measure/Proposed',
                             str(avg_measure[0]))
        self.writer.add_text(f'{group}/Average Measure/Reverberant',
                             str(avg_measure[1]))
        self.writer.close()  # Explicitly close

        print()
        str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
        print(f'Average: {str_avg_measure}')
class Trainer:
    def __init__(self, path_state_dict=''):
        self.model_name = hp.model_name
        module = eval(hp.model_name)  # resolve the model class by name; it must be visible in this scope

        self.model = module(**getattr(hp, hp.model_name))
        self.criterion = nn.MSELoss(reduction='none')
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=hp.learning_rate,
            weight_decay=hp.weight_decay,
        )

        self.__init_device(hp.device, hp.out_device)

        self.scheduler: Optional[CosineLRWithRestarts] = None
        self.max_epochs = hp.n_epochs
        self.loss_last_restart = float('inf')

        self.writer: Optional[CustomWriter] = None

        # a sample in validation set for evaluation
        self.valid_eval_sample: Dict[str, Any] = dict()

        # Load State Dict
        if path_state_dict:
            st_model, st_optim = torch.load(path_state_dict,
                                            map_location=self.in_device)
            try:
                if hasattr(self.model, 'module'):
                    self.model.module.load_state_dict(st_model)
                else:
                    self.model.load_state_dict(st_model)
                self.optimizer.load_state_dict(st_optim)
            except Exception as e:
                raise RuntimeError('The model is different from the state dict.') from e

        path_summary = hp.logdir / 'summary.txt'
        if not path_summary.exists():
            print_to_file(path_summary, summary,
                          (self.model, hp.dummy_input_size),
                          dict(device=self.str_device[:4]))
            with (hp.logdir / 'hparams.txt').open('w') as f:
                f.write(repr(hp))

    def __init_device(self, device, out_device):
        """

        :type device: Union[int, str, Sequence]
        :type out_device: Union[int, str, Sequence]
        :return:
        """
        if device == 'cpu':
            self.in_device = torch.device('cpu')
            self.out_device = torch.device('cpu')
            self.str_device = 'cpu'
            return

        # device type: List[int]
        if type(device) == int:
            device = [device]
        elif type(device) == str:
            device = [int(device.replace('cuda:', ''))]
        else:  # sequence of devices
            if type(device[0]) != int:
                device = [int(d.replace('cuda:', '')) for d in device]

        self.in_device = torch.device(f'cuda:{device[0]}')

        if len(device) > 1:
            if type(out_device) == int:
                self.out_device = torch.device(f'cuda:{out_device}')
            else:
                self.out_device = torch.device(out_device)
            self.str_device = ', '.join([f'cuda:{d}' for d in device])

            self.model = nn.DataParallel(self.model,
                                         device_ids=device,
                                         output_device=self.out_device)
        else:
            self.out_device = self.in_device
            self.str_device = str(self.in_device)

        self.model.cuda(self.in_device)
        self.criterion.cuda(self.out_device)

        torch.cuda.set_device(self.in_device)

    def preprocess(self, data: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
        # B, F, T, C
        x = data['normalized_x']
        y = data['normalized_y']

        x = x.to(self.in_device, non_blocking=True)
        y = y.to(self.out_device, non_blocking=True)

        return x, y

    @torch.no_grad()
    def postprocess(self, output: Tensor, Ts: ndarray, idx: int,
                    dataset: DirSpecDataset) -> Dict[str, ndarray]:
        one = output[idx, :, :, :Ts[idx]]
        if self.model_name.startswith('UNet'):
            one = one.permute(1, 2, 0)  # F, T, C

        one = dataset.denormalize_(y=one)
        one = one.cpu().numpy()

        return dict(out=one)

    def calc_loss(self, output: Tensor, y: Tensor,
                  T_ys: Sequence[int]) -> Tensor:
        loss_batch = self.criterion(output, y)
        loss = torch.zeros(1, device=loss_batch.device)
        for T, loss_sample in zip(T_ys, loss_batch):
            loss += torch.sum(loss_sample[:, :, :T]) / T

        return loss
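    # Note: with reduction='none', `loss_batch` keeps the full B, C, F, T shape;
    # each sample is summed over its valid frames [..., :T] only and divided by T,
    # so zero-padded frames do not contribute to the loss.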

    @torch.no_grad()
    def should_stop(self, loss_valid, epoch):
        if epoch == self.max_epochs - 1:
            return True
        self.scheduler.step()
        # early stopping criterion
        # if self.scheduler.t_epoch == 0:  # if it is restarted now
        #     # if self.loss_last_restart < loss_valid:
        #     #     return True
        #     if self.loss_last_restart * hp.threshold_stop < loss_valid:
        #         self.max_epochs = epoch + self.scheduler.restart_period + 1
        #     self.loss_last_restart = loss_valid

    def train(self,
              loader_train: DataLoader,
              loader_valid: DataLoader,
              logdir: Path,
              first_epoch=0):
        # Learning Rate Scheduler
        self.scheduler = CosineLRWithRestarts(self.optimizer,
                                              batch_size=hp.batch_size,
                                              epoch_size=len(
                                                  loader_train.dataset),
                                              last_epoch=first_epoch - 1,
                                              **hp.scheduler)
        self.scheduler.step()

        self.writer = CustomWriter(str(logdir),
                                   group='train',
                                   purge_step=first_epoch)

        # write DNN structure to tensorboard (does not work properly in PyTorch 1.3)
        # self.writer.add_graph(
        #     self.model.module if hasattr(self.model, 'module') else self.model,
        #     torch.zeros(1, hp.UNet['ch_in'], 256, 256, device=self.in_device),
        # )

        # Start Training
        for epoch in range(first_epoch, hp.n_epochs):

            print()
            pbar = tqdm(loader_train,
                        desc=f'epoch {epoch:3d}',
                        postfix='[]',
                        dynamic_ncols=True)
            avg_loss = AverageMeter(float)

            for i_iter, data in enumerate(pbar):
                # get data
                x, y = self.preprocess(data)  # B, C, F, T
                T_ys = data['T_ys']

                # forward
                output = self.model(x)[..., :y.shape[-1]]  # B, C, F, T

                loss = self.calc_loss(output, y, T_ys)

                # backward
                self.optimizer.zero_grad()
                loss.backward()

                self.optimizer.step()
                self.scheduler.batch_step()

                # print
                avg_loss.update(loss.item(), len(T_ys))
                pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

            self.writer.add_scalar('loss/train', avg_loss.get_average(), epoch)

            # Validation
            loss_valid = self.validate(loader_valid, logdir, epoch)

            # save loss & model
            if epoch % hp.period_save_state == hp.period_save_state - 1:
                torch.save((
                    self.model.module.state_dict(),
                    self.optimizer.state_dict(),
                ), logdir / f'{epoch}.pt')

            # Early stopping
            if self.should_stop(loss_valid, epoch):
                break

        self.writer.close()

    @torch.no_grad()
    def validate(self, loader: DataLoader, logdir: Path, epoch: int):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch:
        """

        self.model.eval()

        avg_loss = AverageMeter(float)

        pbar = tqdm(loader,
                    desc='validate ',
                    postfix='[0]',
                    dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data
            x, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output = self.model(x)[..., :y.shape[-1]]

            # loss
            loss = self.calc_loss(output, y, T_ys)
            avg_loss.update(loss.item(), len(T_ys))

            # print
            pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

            # write summary
            if i_iter == 0:
                # F, T, C
                if epoch == 0:
                    one_sample = DirSpecDataset.decollate_padded(data, 0)
                else:
                    one_sample = dict()

                out_one = self.postprocess(output, T_ys, 0, loader.dataset)

                # DirSpecDataset.save_dirspec(
                #     logdir / hp.form_result.format(epoch),
                #     **one_sample, **out_one
                # )

                self.writer.write_one(epoch, **one_sample, **out_one)

        self.writer.add_scalar('loss/valid', avg_loss.get_average(), epoch)

        self.model.train()

        return avg_loss.get_average()

    @torch.no_grad()
    def test(self, loader: DataLoader, logdir: Path):
        def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
            """ save forward propagation data

            """
            module_name = str(module).split('(')[0]
            dict_to_save = dict()
            # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
            dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

            i_module = module_counts[module_name]
            for i, o in enumerate(dict_to_save['out']):
                save_forward.writer.add_figure(
                    f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                    draw_spectrogram(o, to_db=False),
                    i,
                )
            scio.savemat(
                str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
                dict_to_save,
            )
            module_counts[module_name] += 1

        group = logdir.name.split('_')[0]

        self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()

        # register hook to save output of each block
        module_counts = None
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            save_forward.writer = self.writer
            if isinstance(self.model, nn.DataParallel):
                module = self.model.module
            else:
                module = self.model
            for sub in module.children():
                if isinstance(sub, nn.ModuleList):
                    for m in sub:
                        m.register_forward_hook(save_forward)
                elif isinstance(sub, nn.ModuleDict):
                    for m in sub.values():
                        m.register_forward_hook(save_forward)
                else:
                    sub.register_forward_hook(save_forward)

        pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data
            x, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            if module_counts is not None:
                module_counts = defaultdict(int)

            # when saving block outputs, stop after the first `n_save_block_outs` batches
            if 0 < hp.n_save_block_outs == i_iter:
                break
            output = self.model(x)  # [..., :y.shape[-1]]

            # write summary
            one_sample = DirSpecDataset.decollate_padded(data, 0)  # F, T, C
            out_one = self.postprocess(output, T_ys, 0, loader.dataset)

            # DirSpecDataset.save_dirspec(
            #     logdir / hp.form_result.format(i_iter),
            #     **one_sample, **out_one
            # )

            measure = self.writer.write_one(
                i_iter,
                eval_with_y_ph=hp.eval_with_y_ph,
                **out_one,
                **one_sample,
            )
            if avg_measure is None:
                avg_measure = AverageMeter(init_value=measure,
                                           init_count=len(T_ys))
            else:
                avg_measure.update(measure)

        self.model.train()

        avg_measure = avg_measure.get_average()

        self.writer.add_text('Average Measure/Proposed', str(avg_measure[0]))
        self.writer.add_text('Average Measure/Reverberant',
                             str(avg_measure[1]))
        self.writer.close()  # Explicitly close

        print()
        str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
        print(f'Average: {str_avg_measure}')

    @torch.no_grad()
    def save_result(self, loader: DataLoader, logdir: Path):
        """ save results in npz files without evaluation for deep griffin-lim algorithm

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch:
        """

        import numpy as np

        self.model.eval()

        # avg_loss = AverageMeter(float)

        pbar = tqdm(loader, desc='save ', dynamic_ncols=True)
        i_cum = 0
        for i_iter, data in enumerate(pbar):
            # get data
            x, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output = self.model(x)[..., :y.shape[-1]]  # B, C, F, T

            output = output.permute(0, 2, 3, 1)  # B, F, T, C
            out_denorm = loader.dataset.denormalize_(y=output).cpu().numpy()
            np.maximum(out_denorm, 0, out=out_denorm)
            out_denorm = out_denorm.squeeze()  # B, F, T

            # B, F, T
            x_phase = data['x_phase'][..., :y.shape[-1], 0].numpy()
            y_phase = data['y_phase'].numpy().squeeze()
            out_x_ph = out_denorm * np.exp(1j * x_phase)
            out_y_ph = out_denorm * np.exp(1j * y_phase)

            for i_b, T in enumerate(T_ys):
                # F, T
                noisy = np.ascontiguousarray(out_x_ph[i_b, ..., :T])
                clean = np.ascontiguousarray(out_y_ph[i_b, ..., :T])
                mag = np.ascontiguousarray(out_denorm[i_b, ..., :T])
                length = hp.n_fft + hp.l_hop * (T - 1) - hp.n_fft // 2 * 2

                spec_data = dict(spec_noisy=noisy,
                                 spec_clean=clean,
                                 mag_clean=mag,
                                 length=length)
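                # Each npz holds one utterance: the enhanced magnitude combined with
                # the noisy phase (spec_noisy) and with the clean phase (spec_clean),
                # the enhanced magnitude itself (mag_clean), and the waveform length.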
                np.savez(str(logdir / f'{i_cum + i_b}.npz'), **spec_data)
            i_cum += len(T_ys)

        self.model.train()
Example #4
    def train(self,
              loader_train: DataLoader,
              loader_valid: DataLoader,
              logdir: Path,
              first_epoch=0):
        self.writer = CustomWriter(str(logdir),
                                   group='train',
                                   purge_step=first_epoch)

        # Start Training
        for epoch in range(first_epoch, hp.n_epochs):
            self.writer.add_scalar('loss/lr',
                                   self.optimizer.param_groups[0]['lr'], epoch)
            print()
            pbar = tqdm(loader_train,
                        desc=f'epoch {epoch:3d}',
                        postfix='[]',
                        dynamic_ncols=True)
            avg_loss = AverageMeter(float)
            avg_grad_norm = AverageMeter(float)

            for i_iter, data in enumerate(pbar):
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                T_ys = data['T_ys']

                # forward
                output_loss, _, _ = self.model(
                    x, mag, max_length, repeat=hp.repeat_train)  # B, C, F, T

                loss = self.calc_loss(output_loss, y, T_ys)

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), hp.thr_clip_grad)

                self.optimizer.step()

                # print
                avg_loss.update(loss.item(), len(T_ys))
                pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')
                avg_grad_norm.update(grad_norm)

            self.writer.add_scalar('loss/train', avg_loss.get_average(), epoch)
            self.writer.add_scalar('loss/grad', avg_grad_norm.get_average(),
                                   epoch)

            # Validation
            # loss_valid = self.validate(loader_valid, logdir, epoch)
            loss_valid = self.validate(loader_valid,
                                       logdir,
                                       epoch,
                                       repeat=hp.repeat_train)

            # save loss & model
            if epoch % hp.period_save_state == hp.period_save_state - 1:
                torch.save((
                    self.module.state_dict(),
                    self.optimizer.state_dict(),
                    self.scheduler.state_dict(),
                ), logdir / f'{epoch}.pt')

            # Early stopping
            if self.should_stop(loss_valid, epoch):
                break

        self.writer.close()
    def train(self,
              loader_train: DataLoader,
              loader_valid: DataLoader,
              logdir: Path,
              first_epoch=0):
        # Learning Rate Scheduler
        self.scheduler = CosineLRWithRestarts(self.optimizer,
                                              batch_size=hp.batch_size,
                                              epoch_size=len(
                                                  loader_train.dataset),
                                              last_epoch=first_epoch - 1,
                                              **hp.scheduler)
        self.scheduler.step()

        self.writer = CustomWriter(str(logdir),
                                   group='train',
                                   purge_step=first_epoch)

        # write DNN structure to tensorboard (does not work properly in PyTorch 1.3)
        # self.writer.add_graph(
        #     self.model.module if hasattr(self.model, 'module') else self.model,
        #     torch.zeros(1, hp.UNet['ch_in'], 256, 256, device=self.in_device),
        # )

        # Start Training
        for epoch in range(first_epoch, hp.n_epochs):

            print()
            pbar = tqdm(loader_train,
                        desc=f'epoch {epoch:3d}',
                        postfix='[]',
                        dynamic_ncols=True)
            avg_loss = AverageMeter(float)

            for i_iter, data in enumerate(pbar):
                # get data
                x, y = self.preprocess(data)  # B, C, F, T
                T_ys = data['T_ys']

                # forward
                output = self.model(x)[..., :y.shape[-1]]  # B, C, F, T

                loss = self.calc_loss(output, y, T_ys)

                # backward
                self.optimizer.zero_grad()
                loss.backward()

                self.optimizer.step()
                self.scheduler.batch_step()

                # print
                avg_loss.update(loss.item(), len(T_ys))
                pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

            self.writer.add_scalar('loss/train', avg_loss.get_average(), epoch)

            # Validation
            loss_valid = self.validate(loader_valid, logdir, epoch)

            # save loss & model
            if epoch % hp.period_save_state == hp.period_save_state - 1:
                torch.save((
                    self.model.module.state_dict(),
                    self.optimizer.state_dict(),
                ), logdir / f'{epoch}.pt')

            # Early stopping
            if self.should_stop(loss_valid, epoch):
                break

        self.writer.close()
    def test(self, loader: DataLoader, logdir: Path):
        def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
            """ save forward propagation data

            """
            module_name = str(module).split('(')[0]
            dict_to_save = dict()
            # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
            dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

            i_module = module_counts[module_name]
            for i, o in enumerate(dict_to_save['out']):
                save_forward.writer.add_figure(
                    f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                    draw_spectrogram(o, to_db=False),
                    i,
                )
            scio.savemat(
                str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
                dict_to_save,
            )
            module_counts[module_name] += 1

        group = logdir.name.split('_')[0]

        self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()

        # register hook to save output of each block
        module_counts = None
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            save_forward.writer = self.writer
            if isinstance(self.model, nn.DataParallel):
                module = self.model.module
            else:
                module = self.model
            for sub in module.children():
                if isinstance(sub, nn.ModuleList):
                    for m in sub:
                        m.register_forward_hook(save_forward)
                elif isinstance(sub, nn.ModuleDict):
                    for m in sub.values():
                        m.register_forward_hook(save_forward)
                else:
                    sub.register_forward_hook(save_forward)

        pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data
            x, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            if module_counts is not None:
                module_counts = defaultdict(int)

            # when saving block outputs, stop after the first `n_save_block_outs` batches
            if 0 < hp.n_save_block_outs == i_iter:
                break
            output = self.model(x)  # [..., :y.shape[-1]]

            # write summary
            one_sample = DirSpecDataset.decollate_padded(data, 0)  # F, T, C
            out_one = self.postprocess(output, T_ys, 0, loader.dataset)

            # DirSpecDataset.save_dirspec(
            #     logdir / hp.form_result.format(i_iter),
            #     **one_sample, **out_one
            # )

            measure = self.writer.write_one(
                i_iter,
                eval_with_y_ph=hp.eval_with_y_ph,
                **out_one,
                **one_sample,
            )
            if avg_measure is None:
                avg_measure = AverageMeter(init_value=measure,
                                           init_count=len(T_ys))
            else:
                avg_measure.update(measure)

        self.model.train()

        avg_measure = avg_measure.get_average()

        self.writer.add_text('Average Measure/Proposed', str(avg_measure[0]))
        self.writer.add_text('Average Measure/Reverberant',
                             str(avg_measure[1]))
        self.writer.close()  # Explicitly close

        print()
        str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
        print(f'Average: {str_avg_measure}')
    def test(self, loader: DataLoader, logdir: Path):
        def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
            module_name = str(module).split('(')[0]
            dict_to_save = dict()
            # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
            dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

            i_module = module_counts[module_name]
            for i, o in enumerate(dict_to_save['out']):
                save_forward.writer.add_figure(
                    f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                    draw_spectrogram(o, to_db=False),
                    i,
                )
            scio.savemat(
                str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
                dict_to_save,
            )
            module_counts[module_name] += 1

        group = logdir.name.split('_')[0]

        if self.writer is None:
            self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()
        depth = hp.model['depth']

        module_counts = None
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            save_forward.writer = self.writer
            for sub in self.module.children():
                if isinstance(sub, nn.ModuleList):
                    for m in sub:
                        m.register_forward_hook(save_forward)
                elif isinstance(sub, nn.ModuleDict):
                    for m in sub.values():
                        m.register_forward_hook(save_forward)
                else:
                    sub.register_forward_hook(save_forward)

        ##pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        cnt_sample = 0
        for i_iter, data in enumerate(loader):

            sampleDict = {}
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T

            if hp.noisy_init:
                x = torch.normal(0, 1, x.shape).cuda(self.in_device)

            T_ys = data['T_ys']

            # forward
            if module_counts is not None:
                module_counts = defaultdict(int)

            # if 0 < hp.n_save_block_outs == i_iter:
            #     break
            repeats = 1
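            # Benchmark loop: double `repeats` each pass and compare DeGLI against
            # plain Griffin-Lim (plain_gla) at the same number of iterations,
            # logging STOI/PESQ for both after a few warm-up forward passes.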

            for _ in range(3):
                _, output, residual = self.model(x,
                                                 mag,
                                                 max_length,
                                                 repeat=1,
                                                 train_step=1)  # warm-up pass
                _, output = self.model.plain_gla(x,
                                                 mag,
                                                 max_length,
                                                 repeat=repeats)

            while repeats <= hp.repeat_test:
                stime = ms()
                _, output, residual = self.model(x,
                                                 mag,
                                                 max_length,
                                                 repeat=repeats,
                                                 train_step=1)
                avg_measure = AverageMeter()
                avg_measure2 = AverageMeter()

                etime = ms(stime)
                speed = (max_length / hp.fs) * len(T_ys) / (etime / 1000)
                ##print("degli: %d repeats, length: %d, time: %d miliseconds, ratio = %.02f" % (repeats, max_length , etime, speed))
                ##self.writer.add_scalar("Test Performance/degli", speed, repeats)
                # write summary
                for i_b in tqdm(range(len(T_ys)),
                                desc="degli, %d repeats" % repeats,
                                dynamic_ncols=True):
                    i_sample = cnt_sample + i_b

                    if i_b not in sampleDict:
                        one_sample = ComplexSpecDataset.decollate_padded(
                            data, i_b)
                        reused_sample, result_eval_glim = self.writer.write_zero(
                            0, i_b, **one_sample, suffix="Base stats")
                        sampleDict[i_b] = (reused_sample, result_eval_glim)

                    sampleItem = sampleDict[i_b]
                    reused_sample = sampleItem[0]
                    result_eval_glim = sampleItem[1]

                    out_one = self.postprocess(output, residual, T_ys, i_b,
                                               loader.dataset)

                    # ComplexSpecDataset.save_dirspec(
                    #     logdir / hp.form_result.format(i_sample),
                    #     **one_sample, **out_one
                    # )

                    measure = self.writer.write_one(repeats,
                                                    i_b,
                                                    result_eval_glim,
                                                    reused_sample,
                                                    **out_one,
                                                    suffix="3_deGLI")

                    avg_measure.update(measure)

                stime = ms()
                _, output = self.model.plain_gla(x,
                                                 mag,
                                                 max_length,
                                                 repeat=repeats)

                etime = ms(stime)
                speed = (1000 * max_length / hp.fs) * len(T_ys) / (etime)
                ##print("pure gla: %d repeats, length: %d, time: %d miliseconds, ratio = %.02f" % (repeats, max_length , etime, speed))
                ##self.writer.add_scalar("Test Performance/gla", speed, repeats)

                # write summary
                for i_b in tqdm(range(len(T_ys)),
                                desc="GLA, %d repeats" % repeats,
                                dynamic_ncols=True):
                    i_sample = cnt_sample + i_b
                    sampleItem = sampleDict[i_b]
                    reused_sample = sampleItem[0]
                    result_eval_glim = sampleItem[1]
                    out_one = self.postprocess(output, None, T_ys, i_b,
                                               loader.dataset)
                    measure = self.writer.write_one(repeats,
                                                    i_b,
                                                    result_eval_glim,
                                                    reused_sample,
                                                    **out_one,
                                                    suffix="4_GLA")
                    avg_measure2.update(measure)

                cnt_sample += len(T_ys)
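                # `repeats` doubles each pass, so int(repeats * depth).bit_length()
                # is roughly log2(repeats * depth) + 1, giving evenly spaced steps
                # for the *_semilogx curves below.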

                self.writer.add_scalar(f'STOI/Average Measure/deGLI',
                                       avg_measure.get_average()[0, 0],
                                       repeats * depth)
                self.writer.add_scalar(f'STOI/Average Measure/GLA',
                                       avg_measure2.get_average()[0, 0],
                                       repeats * depth)
                self.writer.add_scalar(f'STOI/Average Measure/deGLI_semilogx',
                                       avg_measure.get_average()[0, 0],
                                       int(repeats * depth).bit_length())
                self.writer.add_scalar(f'STOI/Average Measure/GLA_semilogx',
                                       avg_measure2.get_average()[0, 0],
                                       int(repeats * depth).bit_length())

                self.writer.add_scalar(f'PESQ/Average Measure/deGLI',
                                       avg_measure.get_average()[0, 1],
                                       repeats * depth)
                self.writer.add_scalar(f'PESQ/Average Measure/GLA',
                                       avg_measure2.get_average()[0, 1],
                                       repeats * depth)
                self.writer.add_scalar(f'PESQ/Average Measure/deGLI_semilogx',
                                       avg_measure.get_average()[0, 1],
                                       int(repeats * depth).bit_length())
                self.writer.add_scalar(f'PESQ/Average Measure/GLA_semilogx',
                                       avg_measure2.get_average()[0, 1],
                                       int(repeats * depth).bit_length())

                repeats = repeats * 2
            break  # only the first batch of the loader is used for this benchmark
        self.model.train()

        self.writer.close()  # Explicitly close
    def speedtest(self, loader: DataLoader, logdir: Path):
        group = logdir.name.split('_')[0]

        if self.writer is None:
            self.writer = CustomWriter(str(logdir), group=group)

        depth = hp.model['depth']

        ##pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        repeats = 1
        while repeats * depth <= hp.repeat_test:

            pbar = tqdm(loader,
                        desc="degli performance, %d repeats" % repeats,
                        dynamic_ncols=True)

            stime = time()

            tot_len = 0
            for i_iter, data in enumerate(pbar):
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                _, output, residual = self.model(x,
                                                 mag,
                                                 max_length,
                                                 repeat=repeats,
                                                 train_step=1)
                tot_len = tot_len + max_length * x.size(0)

            etime = int(time() - stime)
            speed = (tot_len / hp.sampling_rate) / (etime)
            self.writer.add_scalar("Test Performance/degli", speed,
                                   repeats * depth)
            self.writer.add_scalar("Test Performance/degli_semilogx", speed,
                                   int(repeats * depth).bit_length())

            repeats = repeats * 2

        repeats = 1
        while repeats * depth <= hp.repeat_test:

            stime = time()
            pbar = tqdm(loader,
                        desc="GLA performance, %d repeats" % repeats,
                        dynamic_ncols=True)
            tot_len = 0

            for i_iter, data in enumerate(pbar):
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                _, output = self.model.plain_gla(x,
                                                 mag,
                                                 max_length,
                                                 repeat=repeats)
                tot_len = tot_len + max_length * x.size(0)

            etime = int(time() - stime)
            speed = (tot_len / hp.sampling_rate) / (etime)
            self.writer.add_scalar("Test Performance/gla", speed,
                                   repeats * depth)
            self.writer.add_scalar("Test Performance/gla_semilogx", speed,
                                   int(repeats * depth).bit_length())

            repeats = repeats * 2

        self.model.train()

        self.writer.close()  # Explicitly close
class Trainer:
    def __init__(self, path_state_dict=''):

        self.writer: Optional[CustomWriter] = None

        config = {
            'vanilla': hp.vanilla_model,
            "ed": hp.ed_model
        }[hp.model_type.lower()]

        self.model = DeGLI(self.writer, config, hp.model_type, hp.n_freq,
                           hp.use_fp16, **hp.model)

        count_parameters(self.model)

        self.criterion = nn.L1Loss(reduction='none')
        if hp.optimizer == "adam":
            self.optimizer = Adam(
                self.model.parameters(),
                lr=hp.learning_rate,
                weight_decay=hp.weight_decay,
            )
        elif hp.optimizer == "sgd":
            self.optimizer = SGD(
                self.model.parameters(),
                lr=hp.learning_rate,
                weight_decay=hp.weight_decay,
            )
        elif hp.optimizer == "radam":
            self.optimizer = RAdam(
                self.model.parameters(),
                lr=hp.learning_rate,
                weight_decay=hp.weight_decay,
            )
        elif hp.optimizer == "novograd":
            self.optimizer = NovoGrad(self.model.parameters(),
                                      lr=hp.learning_rate,
                                      weight_decay=hp.weight_decay)
        elif hp.optimizer == "sm3":
            raise NameError('sm3 not implemented')
        else:
            raise NameError('optimizer not implemented')

        self.module = self.model

        # self.optimizer = SGD(self.model.parameters(),
        #                       lr=hp.learning_rate,
        #                       weight_decay=hp.weight_decay,
        #                       )

        self.__init_device(hp.device, hp.out_device)

        if hp.use_fp16:
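            # NVIDIA apex mixed precision; opt_level 'O1' patches selected ops to run
            # in fp16 while the model weights stay in fp32.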
            from apex import amp
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level='O1')

        self.reused_sample = None
        self.result_eval_glim = None

        ##if  hp.optimizer == "novograd":
        ##    self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, 744*3 ,1e-4)
        ##else:
        self.scheduler = lr_scheduler.ReduceLROnPlateau(
            self.optimizer, **hp.scheduler)
        self.max_epochs = hp.n_epochs

        self.valid_eval_sample: Dict[str, Any] = dict()

        # if hp.model['final_avg']:
        #     len_weight = hp.repeat_train
        # else:
        #     len_weight = hp.model['depth'] * hp.repeat_train
        len_weight = hp.repeat_train
        self.loss_weight = torch.tensor(
            [1. / i for i in range(len_weight, 0, -1)],
            device=self.out_device,
        )
        self.loss_weight /= self.loss_weight.sum()

        # Load State Dict
        if path_state_dict:
            st_model, st_optim, st_sched = torch.load(
                path_state_dict, map_location=self.in_device)
            try:
                self.module.load_state_dict(st_model)
                self.optimizer.load_state_dict(st_optim)
                self.scheduler.load_state_dict(st_sched)
            except Exception as e:
                raise RuntimeError('The model is different from the state dict.') from e

        path_summary = hp.logdir / 'summary.txt'
        if not path_summary.exists():
            # print_to_file(
            #     path_summary,
            #     summary,
            #     (self.model, hp.dummy_input_size),
            #     dict(device=self.str_device[:4])
            # )
            with path_summary.open('w') as f:
                f.write('\n')
            with (hp.logdir / 'hparams.txt').open('w') as f:
                f.write(repr(hp))

    def __init_device(self, device, out_device):
        """

        :type device: Union[int, str, Sequence]
        :type out_device: Union[int, str, Sequence]
        :return:
        """
        if device == 'cpu':
            self.in_device = torch.device('cpu')
            self.out_device = torch.device('cpu')
            self.str_device = 'cpu'
            return

        # device type: List[int]
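        # Accepted forms for `device`: an int like 2, a string like 'cuda:1' or
        # 'cuda:0,cuda:2', any string starting with 'a' (e.g. 'all') to use every
        # visible GPU, or a sequence of ints / 'cuda:N' strings.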
        if type(device) == int:
            device = [device]
        elif type(device) == str:
            if device[0] == 'a':
                device = [x for x in range(torch.cuda.device_count())]
            else:
                device = [
                    int(d.replace('cuda:', '')) for d in device.split(",")
                ]
            print("Used devices = %s" % device)
        else:  # sequence of devices
            if type(device[0]) != int:
                device = [int(d.replace('cuda:', '')) for d in device]

        self.in_device = torch.device(f'cuda:{device[0]}')

        if len(device) > 1:
            if type(out_device) == int:
                self.out_device = torch.device(f'cuda:{out_device}')
            else:
                self.out_device = torch.device(out_device)
            self.str_device = ', '.join([f'cuda:{d}' for d in device])

            self.model = nn.DataParallel(self.model,
                                         device_ids=device,
                                         output_device=self.out_device)
        else:
            self.out_device = self.in_device
            self.str_device = str(self.in_device)

        self.model.cuda(self.in_device)
        self.criterion.cuda(self.out_device)

        torch.cuda.set_device(self.in_device)

    def preprocess(self, data: Dict[str, Tensor]) -> Tuple[Tensor, Tensor, Any, Tensor]:
        # B, F, T, C
        x = data['x']
        mag = data['y_mag']
        max_length = max(data['length'])
        y = data['y']

        x = x.to(self.in_device, non_blocking=True)
        mag = mag.to(self.in_device, non_blocking=True)
        y = y.to(self.out_device, non_blocking=True)

        return x, mag, max_length, y

    @torch.no_grad()
    def postprocess(self, output: Tensor, residual: Tensor, Ts: ndarray,
                    idx: int,
                    dataset: ComplexSpecDataset) -> Dict[str, ndarray]:
        dict_one = dict(out=output, res=residual)
        for key in dict_one:
            if dict_one[key] is None:
                continue
            one = dict_one[key][idx, :, :, :Ts[idx]]
            one = one.permute(1, 2, 0).contiguous()  # F, T, 2
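            # The trailing real/imag pair is reinterpreted as a single complex64
            # value by the numpy dtype view below, giving shape F, T, 1.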

            one = one.cpu().numpy().view(dtype=np.complex64)  # F, T, 1
            dict_one[key] = one

        return dict_one

    def calc_loss(self, out_blocks: Tensor, y: Tensor,
                  T_ys: Sequence[int]) -> Tensor:
        """
        out_blocks: B, depth, C, F, T
        y: B, C, F, T
        """

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            loss_no_red = self.criterion(out_blocks, y.unsqueeze(1))
        loss_blocks = torch.zeros(out_blocks.shape[1], device=y.device)
        for T, loss_batch in zip(T_ys, loss_no_red):
            loss_blocks += torch.mean(loss_batch[..., :T], dim=(1, 2, 3))

        if len(loss_blocks) == 1:
            loss = loss_blocks.squeeze()
        else:
            loss = loss_blocks @ self.loss_weight
        return loss

    @torch.no_grad()
    def should_stop(self, loss_valid, epoch):
        if epoch == self.max_epochs - 1:
            return True
        self.scheduler.step(loss_valid)
        # if self.scheduler.t_epoch == 0:  # if it is restarted now
        #     # if self.loss_last_restart < loss_valid:
        #     #     return True
        #     if self.loss_last_restart * hp.threshold_stop < loss_valid:
        #         self.max_epochs = epoch + self.scheduler.restart_period + 1
        #     self.loss_last_restart = loss_valid

    def train(self,
              loader_train: DataLoader,
              loader_valid: DataLoader,
              logdir: Path,
              first_epoch=0):
        self.writer = CustomWriter(str(logdir),
                                   group='train',
                                   purge_step=first_epoch)

        # Start Training
        step = 0

        loss_valid = self.validate(loader_valid,
                                   logdir,
                                   0,
                                   step,
                                   repeat=hp.repeat_train)

        for epoch in range(first_epoch, hp.n_epochs):

            self.writer.add_scalar('loss/lr',
                                   self.optimizer.param_groups[0]['lr'], epoch)
            pbar = tqdm(loader_train,
                        desc=f'epoch {epoch:3d}',
                        postfix='[]',
                        dynamic_ncols=True)
            avg_loss = AverageMeter(float)
            avg_grad_norm = AverageMeter(float)

            for i_iter, data in enumerate(pbar):
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                T_ys = data['T_ys']
                # forward
                output_loss, _, _ = self.model(x,
                                               mag,
                                               max_length,
                                               repeat=hp.repeat_train,
                                               train_step=step)  # B, C, F, T
                step = step + 1

                loss = self.calc_loss(output_loss, y, T_ys)

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), hp.thr_clip_grad)

                self.optimizer.step()

                # print
                # if np.any(np.isnan(loss.item())):
                #     raise NameError('Loss is Nan!')

                # for vname,var in self.model.named_parameters():
                #     if np.any(np.isnan(var.detach().cpu().numpy())):
                #         print("nan detected in %s " % vname)

                ##import pdb; pdb.set_trace()

                avg_loss.update(loss.item(), len(T_ys))
                pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')
                avg_grad_norm.update(grad_norm)

                if i_iter % 25 == 0:
                    self.writer.add_scalar('loss/train',
                                           avg_loss.get_average(),
                                           epoch * len(loader_train) + i_iter)
                    self.writer.add_scalar('loss/grad',
                                           avg_grad_norm.get_average(),
                                           epoch * len(loader_train) + i_iter)
                    avg_loss = AverageMeter(float)
                    avg_grad_norm = AverageMeter(float)

            # Validation
            # loss_valid = self.validate(loader_valid, logdir, epoch)
            loss_valid = self.validate(loader_valid,
                                       logdir,
                                       epoch + 1,
                                       step,
                                       repeat=hp.repeat_train)

            # save loss & model
            if epoch % hp.period_save_state == hp.period_save_state - 1:
                torch.save((
                    self.module.state_dict(),
                    self.optimizer.state_dict(),
                    self.scheduler.state_dict(),
                ), logdir / f'{epoch+1}.pt')

            # Early stopping
            if self.should_stop(loss_valid, epoch):
                break

        self.writer.close()

    @torch.no_grad()
    def validate(self,
                 loader: DataLoader,
                 logdir: Path,
                 epoch: int,
                 step,
                 repeat=1):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch: epoch index used as the summary-writer step.
        :param step: global training step passed to the model.
        :param repeat: number of repeats per forward pass.
        """
        suffix = f'_{repeat}' if repeat > 1 else ''

        self.model.eval()
        stoi_cnt = 0
        stoi_cntX = 0
        stoi_iters = hp.stoi_iters
        stoi_iters_rate = hp.stoi_iters_rate

        avg_loss = AverageMeter(float)
        avg_measure = AverageMeter(float)
        pesq_avg_measure = AverageMeter(float)

        avg_measureX = AverageMeter(float)
        pesq_avg_measureX = AverageMeter(float)

        pbar = tqdm(loader,
                    desc='validate ',
                    postfix='[0]',
                    dynamic_ncols=True)

        num_iters = len(pbar)
        for i_iter, data in enumerate(pbar):
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output_loss, output, residual = self.model(x,
                                                       mag,
                                                       max_length,
                                                       repeat=repeat,
                                                       train_step=step)

            # loss
            loss = self.calc_loss(output_loss, y, T_ys)
            avg_loss.update(loss.item(), len(T_ys))

            # print
            pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

            # write summary
            # if i_iter == 0:
            #     if self.reused_sample is None:
            #         one_sample = ComplexSpecDataset.decollate_padded(data, i_iter)
            #         self.reused_sample, self.result_eval_glim = self.writer.write_zero(0, i_iter, **one_sample, suffix="Base stats")

            #     out_one = self.postprocess(output, residual, T_ys, i_iter, loader.dataset)
            #     self.writer.write_one(0, i_iter, self.result_eval_glim, self.reused_sample ,**out_one, suffix="deGLI")

            if stoi_cnt <= hp.num_stoi:
                ##import pdb; pdb.set_trace()
                for p in range(min(hp.num_stoi // num_iters, len(T_ys))):
                    y_wav = data['wav'][p]
                    out = self.postprocess(output, None, T_ys, p, None)['out']
                    out_wav = reconstruct_wave(out, n_sample=data['length'][p])

                    measure = calc_using_eval_module(y_wav, out_wav)
                    stoi = measure['STOI']
                    pesq_score = measure['PESQ']
                    avg_measure.update(stoi)
                    pesq_avg_measure.update(pesq_score)

                    stoi_cnt = stoi_cnt + 1

            if (stoi_iters > 0) and (epoch % stoi_iters_rate == 0):
                _, output, _ = self.model(x,
                                          mag,
                                          max_length,
                                          repeat=stoi_iters,
                                          train_step=step)

                if stoi_cntX <= hp.num_stoi:
                    ##import pdb; pdb.set_trace()
                    for p in range(min(hp.num_stoi // num_iters, len(T_ys))):
                        y_wav = data['wav'][p]
                        out = self.postprocess(output, None, T_ys, p,
                                               None)['out']
                        out_wav = reconstruct_wave(out,
                                                   n_sample=data['length'][p])

                        measure = calc_using_eval_module(y_wav, out_wav)
                        stoi = measure['STOI']
                        pesq_score = measure['PESQ']
                        avg_measureX.update(stoi)
                        pesq_avg_measureX.update(pesq_score)

                        stoi_cntX = stoi_cntX + 1

        self.writer.add_scalar(f'loss/valid', avg_loss.get_average(), epoch)
        self.writer.add_scalar(f'loss/STOI', avg_measure.get_average(), epoch)
        self.writer.add_scalar(f'loss/PESQ', pesq_avg_measure.get_average(),
                               epoch)

        if (stoi_iters > 0) and (epoch % stoi_iters_rate == 0):
            self.writer.add_scalar(f'loss/PESQ_X{stoi_iters}',
                                   pesq_avg_measureX.get_average(), epoch)
            self.writer.add_scalar(f'loss/STOI_X{stoi_iters}',
                                   avg_measureX.get_average(), epoch)

        self.model.train()

        return avg_loss.get_average()

    @torch.no_grad()
    def test(self, loader: DataLoader, logdir: Path):
        def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
            module_name = str(module).split('(')[0]
            dict_to_save = dict()
            # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
            dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

            i_module = module_counts[module_name]
            for i, o in enumerate(dict_to_save['out']):
                save_forward.writer.add_figure(
                    f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                    draw_spectrogram(o, to_db=False),
                    i,
                )
            scio.savemat(
                str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
                dict_to_save,
            )
            module_counts[module_name] += 1

        group = logdir.name.split('_')[0]

        if self.writer is None:
            self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()
        depth = hp.model['depth']

        module_counts = None
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            save_forward.writer = self.writer
            for sub in self.module.children():
                if isinstance(sub, nn.ModuleList):
                    for m in sub:
                        m.register_forward_hook(save_forward)
                elif isinstance(sub, nn.ModuleDict):
                    for m in sub.values():
                        m.register_forward_hook(save_forward)
                else:
                    sub.register_forward_hook(save_forward)

        ##pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        cnt_sample = 0
        for i_iter, data in enumerate(loader):

            sampleDict = {}
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T

            if hp.noisy_init:
                x = torch.normal(0, 1, x.shape).cuda(self.in_device)

            T_ys = data['T_ys']

            # forward
            if module_counts is not None:
                module_counts = defaultdict(int)

            # if 0 < hp.n_save_block_outs == i_iter:
            #     break
            repeats = 1

            for _ in range(3):
                _, output, residual = self.model(x,
                                                 mag,
                                                 max_length,
                                                 repeat=1,
                                                 train_step=1)  # warm-up pass
                _, output = self.model.plain_gla(x,
                                                 mag,
                                                 max_length,
                                                 repeat=repeats)

            while repeats <= hp.repeat_test:
                stime = ms()
                _, output, residual = self.model(x,
                                                 mag,
                                                 max_length,
                                                 repeat=repeats,
                                                 train_step=1)
                avg_measure = AverageMeter()
                avg_measure2 = AverageMeter()

                etime = ms(stime)
                speed = (max_length / hp.fs) * len(T_ys) / (etime / 1000)
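                # Approximate real-time factor: (upper-bound) seconds of audio in the
                # batch divided by the elapsed wall-clock time.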
                ##print("degli: %d repeats, length: %d, time: %d miliseconds, ratio = %.02f" % (repeats, max_length , etime, speed))
                ##self.writer.add_scalar("Test Performance/degli", speed, repeats)
                # write summary
                for i_b in tqdm(range(len(T_ys)),
                                desc="degli, %d repeats" % repeats,
                                dynamic_ncols=True):
                    i_sample = cnt_sample + i_b

                    if i_b not in sampleDict:
                        one_sample = ComplexSpecDataset.decollate_padded(
                            data, i_b)
                        reused_sample, result_eval_glim = self.writer.write_zero(
                            0, i_b, **one_sample, suffix="Base stats")
                        sampleDict[i_b] = (reused_sample, result_eval_glim)

                    sampleItem = sampleDict[i_b]
                    reused_sample = sampleItem[0]
                    result_eval_glim = sampleItem[1]

                    out_one = self.postprocess(output, residual, T_ys, i_b,
                                               loader.dataset)

                    # ComplexSpecDataset.save_dirspec(
                    #     logdir / hp.form_result.format(i_sample),
                    #     **one_sample, **out_one
                    # )

                    measure = self.writer.write_one(repeats,
                                                    i_b,
                                                    result_eval_glim,
                                                    reused_sample,
                                                    **out_one,
                                                    suffix="3_deGLI")

                    avg_measure.update(measure)

                stime = ms()
                _, output = self.model.plain_gla(x,
                                                 mag,
                                                 max_length,
                                                 repeat=repeats)

                etime = ms(stime)
                speed = (1000 * max_length / hp.fs) * len(T_ys) / (etime)
                ##print("pure gla: %d repeats, length: %d, time: %d miliseconds, ratio = %.02f" % (repeats, max_length , etime, speed))
                ##self.writer.add_scalar("Test Performance/gla", speed, repeats)

                # write summary
                for i_b in tqdm(range(len(T_ys)),
                                desc="GLA, %d repeats" % repeats,
                                dynamic_ncols=True):
                    i_sample = cnt_sample + i_b
                    sampleItem = sampleDict[i_b]
                    reused_sample = sampleItem[0]
                    result_eval_glim = sampleItem[1]
                    out_one = self.postprocess(output, None, T_ys, i_b,
                                               loader.dataset)
                    measure = self.writer.write_one(repeats,
                                                    i_b,
                                                    result_eval_glim,
                                                    reused_sample,
                                                    **out_one,
                                                    suffix="4_GLA")
                    avg_measure2.update(measure)

                cnt_sample += len(T_ys)

                self.writer.add_scalar(f'STOI/Average Measure/deGLI',
                                       avg_measure.get_average()[0, 0],
                                       repeats * depth)
                self.writer.add_scalar(f'STOI/Average Measure/GLA',
                                       avg_measure2.get_average()[0, 0],
                                       repeats * depth)
                self.writer.add_scalar(f'STOI/Average Measure/deGLI_semilogx',
                                       avg_measure.get_average()[0, 0],
                                       int(repeats * depth).bit_length())
                self.writer.add_scalar(f'STOI/Average Measure/GLA_semilogx',
                                       avg_measure2.get_average()[0, 0],
                                       int(repeats * depth).bit_length())

                self.writer.add_scalar(f'PESQ/Average Measure/deGLI',
                                       avg_measure.get_average()[0, 1],
                                       repeats * depth)
                self.writer.add_scalar(f'PESQ/Average Measure/GLA',
                                       avg_measure2.get_average()[0, 1],
                                       repeats * depth)
                self.writer.add_scalar(f'PESQ/Average Measure/deGLI_semilogx',
                                       avg_measure.get_average()[0, 1],
                                       int(repeats * depth).bit_length())
                self.writer.add_scalar(f'PESQ/Average Measure/GLA_semilogx',
                                       avg_measure2.get_average()[0, 1],
                                       int(repeats * depth).bit_length())

                repeats = repeats * 2
            break
        self.model.train()

        self.writer.close()  # Explicitly close

        ##print()
        ##str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
        ##print(f'Average: {str_avg_measure}')

    @torch.no_grad()
    def speedtest(self, loader: DataLoader, logdir: Path):
        group = logdir.name.split('_')[0]

        if self.writer is None:
            self.writer = CustomWriter(str(logdir), group=group)

        depth = hp.model['depth']

        ##pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        repeats = 1
        while repeats * depth <= hp.repeat_test:

            pbar = tqdm(loader,
                        desc="degli performance, %d repeats" % repeats,
                        dynamic_ncols=True)

            stime = time()

            tot_len = 0
            for i_iter, data in enumerate(pbar):
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                _, output, residual = self.model(x,
                                                 mag,
                                                 max_length,
                                                 repeat=repeats,
                                                 train_step=1)
                tot_len = tot_len + max_length * x.size(0)

            etime = int(time() - stime)
            speed = (tot_len / hp.sampling_rate) / (etime)
            self.writer.add_scalar("Test Performance/degli", speed,
                                   repeats * depth)
            self.writer.add_scalar("Test Performance/degli_semilogx", speed,
                                   int(repeats * depth).bit_length())

            repeats = repeats * 2

        repeats = 1
        while repeats * depth <= hp.repeat_test:

            stime = time()
            pbar = tqdm(loader,
                        desc="GLA performance, %d repeats" % repeats,
                        dynamic_ncols=True)
            tot_len = 0

            for i_iter, data in enumerate(pbar):
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                _, output = self.model.plain_gla(x,
                                                 mag,
                                                 max_length,
                                                 repeat=repeats)
                tot_len = tot_len + max_length * x.size(0)

            etime = int(time() - stime)
            speed = (tot_len / hp.sampling_rate) / (etime)
            self.writer.add_scalar("Test Performance/gla", speed,
                                   repeats * depth)
            self.writer.add_scalar("Test Performance/gla_semilogx", speed,
                                   int(repeats * depth).bit_length())

            repeats = repeats * 2

        self.model.train()

        self.writer.close()  # Explicitly close
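
# Hypothetical driver sketch: `build_loaders()` is assumed to exist elsewhere;
# only the Trainer API defined above (train / test / speedtest) is exercised.
#
#   trainer = Trainer(path_state_dict='')          # pass a checkpoint path to resume
#   loader_train, loader_valid, loader_test = build_loaders()
#   trainer.train(loader_train, loader_valid, hp.logdir, first_epoch=0)
#   trainer.test(loader_test, hp.logdir / 'test')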
Example #11
0
    def inspect(self, loader: DataLoader, logdir: Path):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch:
        """
        self.model.eval()

        os.makedirs(Path(logdir), exist_ok=True)
        self.writer = CustomWriter(str(logdir), group='test')

        ##import pdb; pdb.set_trace()
        num_filters = len(self.filters)

        avg_loss1 = AverageMeter(float)
        avg_lozz1 = AverageMeter(float)
        avg_loss2 = AverageMeter(float)
        avg_lozz2 = AverageMeter(float)

        avg_loss_tot = AverageMeter(float)
        avg_losses = [AverageMeter(float) for _ in range(num_filters) ]
        avg_losses_base = [AverageMeter(float) for _ in range(num_filters) ]
        losses = [None] * num_filters
        losses_base = [None] * num_filters

        cnt = 0

        pbar = tqdm(enumerate(loader), desc='loss inspection', dynamic_ncols=True)

        for i_iter, data in pbar:

            ##import pdb; pdb.set_trace()
            y = self.preprocess(data)  # B, C, F, T
            x_mel = self.model.spec_to_mel(y) 

            z = self.model.mel_pseudo_inverse(x_mel)

            T_ys = data['T_ys']
            x = self.model(x_mel)  # B, C, F, T
            y_mel = self.model.spec_to_mel(x)     
            z_mel = self.model.spec_to_mel(y)

            loss1 = self.calc_loss(x, y, T_ys, self.criterion)
            lozz1 = self.calc_loss(z, y, T_ys, self.criterion)

            loss2 = self.calc_loss(x_mel, y_mel, T_ys, self.criterion2)
            lozz2 = self.calc_loss(z_mel, x_mel, T_ys, self.criterion2)

            loss = loss1 + loss2*hp.l2_factor

            # for i,f in enumerate(self.filters):
            #     s = self.f_specs[i][1]
            #     losses[i] = self.calc_loss_smooth(x,y,T_ys,f, s )
            #     loss = loss + losses[i]

            for i,(k,s) in enumerate(self.f_specs):
                losses[i] = self.calc_loss_smooth2(x,y,T_ys,k, s )
                losses_base[i] = self.calc_loss_smooth2(y,y,T_ys,k, s )

                loss = loss + losses[i]
            avg_loss1.update(loss1.item(), len(T_ys))
            avg_lozz1.update(lozz1.item(), len(T_ys))
            avg_loss2.update(loss2.item(), len(T_ys))
            avg_lozz2.update(lozz2.item(), len(T_ys))
            avg_loss_tot.update(loss.item(), len(T_ys))

            for j,l in enumerate(losses):
                avg_losses[j].update(l.item(), len(T_ys))
                
            for j,l in enumerate(losses_base):
                avg_losses_base[j].update(l.item(), len(T_ys))                
            # print
            ##pbar.set_postfix_str(f'{avg_loss1.get_average():.1e}')

            # write summary

            if 0:
                for p in range(len(T_ys)):
                    _x = x[p,0,:,:T_ys[p]].cpu()
                    _y = y[p,0,:,:T_ys[p]].cpu()
                    _z = z[p,0,:,:T_ys[p]].cpu()
                    y_wav = data['wav'][p]

                    ymin = _y[_y > 0].min()
                    vmin, vmax = librosa.amplitude_to_db(np.array((ymin, _y.max())))
                    kwargs_fig = dict(vmin=vmin, vmax=vmax)


                    if hp.request_drawings:
                        fig_x = draw_spectrogram(_x, **kwargs_fig)
                        self.writer.add_figure(f'Audio/1_DNN_Output', fig_x, cnt)
                        fig_y = draw_spectrogram(_y, **kwargs_fig)
                        fig_z = draw_spectrogram(_z, **kwargs_fig)
                        self.writer.add_figure(f'Audio/0_Pseudo_Inverse', fig_z, cnt)
                        self.writer.add_figure(f'Audio/2_Real_Spectrogram', fig_y, cnt)

                    audio_x = self.audio_from_mag_spec(np.abs(_x.numpy()))
                    x_scale = np.abs(audio_x).max() / 0.5

                    self.writer.add_audio(f'LWS/1_DNN_Output',
                                torch.from_numpy(audio_x / x_scale),
                                cnt,
                                sample_rate=hp.sampling_rate)

                    audio_y = self.audio_from_mag_spec(_y.numpy())
                    audio_z = self.audio_from_mag_spec(_z.numpy())
                    
                    z_scale = np.abs(audio_z).max() / 0.5
                    y_scale = np.abs(audio_y).max() / 0.5

                    self.writer.add_audio(f'LWS/0_Pseudo_Inverse',
                                torch.from_numpy(audio_z / z_scale),
                                cnt,
                                sample_rate=hp.sampling_rate)


                    self.writer.add_audio(f'LWS/2_Real_Spectrogram',
                                torch.from_numpy(audio_y / y_scale),
                                cnt,
                                sample_rate=hp.sampling_rate)

                    ##import pdb; pdb.set_trace()

                    stoi_scores = {'0_Pseudo_Inverse'       : self.calc_stoi(y_wav, audio_z),
                                '1_DNN_Output'           : self.calc_stoi(y_wav, audio_x),
                                '2_Real_Spectrogram'     : self.calc_stoi(y_wav, audio_y)}

                    self.writer.add_scalars(f'LWS/STOI', stoi_scores, cnt )
                    # self.writer.add_scalar(f'STOI/0_Pseudo_Inverse_LWS', self.calc_stoi(y_wav, audio_z) , cnt)
                    # self.writer.add_scalar(f'STOI/1_DNN_Output_LWS', self.calc_stoi(y_wav, audio_x) , cnt)
                    # self.writer.add_scalar(f'STOI/2_Real_Spectrogram_LWS', self.calc_stoi(y_wav, audio_y) , cnt)
                    cnt = cnt + 1

        for j, avg_loss in enumerate(avg_losses):
            k = self.f_specs[j][0]
            s = self.f_specs[j][1]
            self.writer.add_scalar(f'inspect/losses_breakdown', avg_loss.get_average(), j)

        for j, avg_loss in enumerate(avg_losses_base):
            k = self.f_specs[j][0]
            s = self.f_specs[j][1]
            self.writer.add_scalar(f'inspect/losses_base_breakdown', avg_loss.get_average(), j)

        for j, avg_loss in enumerate(avg_losses):
            avg_loss_base = avg_losses_base[j]
            k = self.f_specs[j][0]
            s = self.f_specs[j][1]
            self.writer.add_scalar(f'inspect/losses_normalized_breakdown', avg_loss_base.get_average() / avg_loss.get_average(), j)


        # self.writer.add_scalar(f'valid/loss', avg_loss1.get_average(), epoch)
        # self.writer.add_scalar(f'valid/baseline', avg_lozz1.get_average(), epoch)
        # self.writer.add_scalar(f'valid/melinv_loss', avg_loss2.get_average(), epoch)
        # self.writer.add_scalar(f'valid/melinv_baseline', avg_lozz2.get_average(), epoch)

        # for j, avg_loss in enumerate(avg_losses):
        #     k = self.f_specs[j][0]
        #     s = self.f_specs[j][1]
        #     self.writer.add_scalar(f'valid/losses_{k}_{s}', avg_loss.get_average(), epoch)
        # self.writer.add_scalar('valid/loss_total', avg_loss_tot.get_average(), epoch)

        self.model.train()

        return 
Example #12
0
class Trainer:

    def __init__(self, path_state_dict=''):

        ##import pdb; pdb.set_trace()
        self.writer: Optional[CustomWriter] = None
        meltrans = create_mel_filterbank( hp.sampling_rate, hp.n_fft, fmin=hp.mel_fmin, fmax=hp.mel_fmax, n_mels=hp.mel_freq)

        self.model = melGen(self.writer, hp.n_freq, meltrans, hp.mel_generator)
        count_parameters(self.model)

        self.module = self.model

        self.lws_processor = lws.lws(hp.n_fft, hp.l_hop, mode='speech', perfectrec=False)

        self.prev_stoi_scores = {}
        self.base_stoi_scores = {}

        if hp.crit == "l1":
            self.criterion = nn.L1Loss(reduction='none')
        elif hp.crit == "l2":
            self.criterion = nn.MSELoss(reduction='none')  # L2 loss; PyTorch names it MSELoss
        else:
            raise NameError('loss not implemented')

        self.criterion2 = nn.L1Loss(reduction='none')
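
        # (kernel, stride) pairs for the max-pool smoothness losses; the list is
        # selected by hp.loss_mode.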
        self.f_specs = {
            0: [(5, 2), (15, 5)],
            1: [(5, 2)],
            2: [(3, 1)],
            3: [(3, 1), (5, 2)],
            4: [(3, 1), (5, 2), (7, 3)],
            5: [(15, 5)],
            6: [(3, 1), (5, 2), (7, 3), (15, 5), (25, 10)],
            7: [(1, 1)],
            8: [(1, 1), (3, 1), (5, 2), (15, 5), (7, 3), (25, 10), (9, 4), (20, 5), (5, 3)],
        }[hp.loss_mode]

        self.filters = [gen_filter(k) for k, s in self.f_specs]

        if hp.optimizer == "adam":
            self.optimizer = Adam(self.model.parameters(),
                                lr=hp.learning_rate,
                                weight_decay=hp.weight_decay,
                                )
        elif hp.optimizer == "sgd":
            self.optimizer = SGD(self.model.parameters(),
                                lr=hp.learning_rate,
                                weight_decay=hp.weight_decay,
                                )
        elif hp.optimizer == "radam":
            self.optimizer = RAdam(self.model.parameters(),
                                lr=hp.learning_rate,
                                weight_decay=hp.weight_decay,
                                )
        elif hp.optimizer == "novograd":
            self.optimizer = NovoGrad(self.model.parameters(), 
                                    lr=hp.learning_rate, 
                                    weight_decay=hp.weight_decay
                                    )
        elif hp.optimizer == "sm3":
            raise NameError('sm3 not implemented')
        else:
            raise NameError('optimizer not implemented')


        self.__init_device(hp.device)

        ##if  hp.optimizer == "novograd":
        ##    self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, 200 ,1e-5)
        ##else:
        self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, **hp.scheduler)

        self.max_epochs = hp.n_epochs


        self.valid_eval_sample: Dict[str, Any] = dict()

        # len_weight = hp.repeat_train
        # self.loss_weight = torch.tensor(
        #     [1./i for i in range(len_weight, 0, -1)],
        # )
        # self.loss_weight /= self.loss_weight.sum()

        # Load State Dict
        if path_state_dict:
            st_model, st_optim, st_sched = torch.load(path_state_dict, map_location=self.in_device)
            try:
                self.module.load_state_dict(st_model)
                self.optimizer.load_state_dict(st_optim)
                self.scheduler.load_state_dict(st_sched)
            except:
                raise Exception('The model is different from the state dict.')


        path_summary = hp.logdir / 'summary.txt'
        if not path_summary.exists():
            # print_to_file(
            #     path_summary,
            #     summary,
            #     (self.model, hp.dummy_input_size),
            #     dict(device=self.str_device[:4])
            # )
            with path_summary.open('w') as f:
                f.write('\n')
            with (hp.logdir / 'hparams.txt').open('w') as f:
                f.write(repr(hp))

    def __init_device(self, device):
        """

        :type device: Union[int, str, Sequence]
        :return:
        """


        # device type: List[int]
        if type(device) == int:
            device = [device]
        elif type(device) == str:
            if device[0] == 'a':
                device = [x for x in range(torch.cuda.device_count())]
            else:
                device = [int(d.replace('cuda:', '')) for d in device.split(",")]
            print("Used devices = %s" % device)
        else:  # sequence of devices
            if type(device[0]) != int:
                device = [int(d.replace('cuda:', '')) for d in device]
        self.num_workers = len(device)
        if len(device) > 1:
            self.model = nn.DataParallel(self.model, device_ids=device)

        self.in_device = torch.device(f'cuda:{device[0]}')
        torch.cuda.set_device(self.in_device)

        self.model.cuda()
        self.criterion.cuda()
        self.criterion2.cuda()
        self.filters = [f.cuda() for f in self.filters]

    def preprocess(self, data: Dict[str, Tensor]) -> Tensor:
        # B, F, T, C
        y = data['y']
        y = y.cuda()

        return y

    @torch.no_grad()
    def postprocess(self, output: Tensor, residual: Tensor, Ts: ndarray, idx: int,
                    dataset: ComplexSpecDataset) -> Dict[str, ndarray]:
        dict_one = dict(out=output, res=residual)
        for key in dict_one:
            if dict_one[key] is None:
                continue
            one = dict_one[key][idx, :, :, :Ts[idx]]
            one = one.permute(1, 2, 0).contiguous()  # F, T, 2

            one = one.cpu().numpy().view(dtype=np.complex64)  # F, T, 1
            dict_one[key] = one

        return dict_one

    def calc_loss(self, x: Tensor, y: Tensor, T_ys: Sequence[int], crit) -> Tensor:
        """
        x: B, C, F, T
        y: B, C, F, T
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            loss_no_red = crit(x, y)

        loss_blocks = torch.zeros(x.shape[1], device=y.device)

        tot =0 
        for T, loss_batch in zip(T_ys, loss_no_red):
            tot += T
            loss_blocks += torch.sum(loss_batch[..., :T])
        loss_blocks = loss_blocks / tot

        if len(loss_blocks) == 1:
            loss = loss_blocks.squeeze()
        else:
            loss = loss_blocks @ self.loss_weight
        return loss

    def calc_loss_smooth(self, _x: Tensor, _y: Tensor, T_ys: Sequence[int], filter, stride: int ,pad: int = 0) -> Tensor:
        """
        _x: B, C, F, T
        _y: B, C, F, T
        """

        crit = self.criterion 
        x = F.conv2d(_x, filter, stride = stride)
        y = F.conv2d(_y, filter, stride = stride)

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            loss_no_red = crit(x, y)

        loss_blocks = torch.zeros(x.shape[1], device=y.device)

        tot =0 
        for T, loss_batch in zip(T_ys, loss_no_red):
            tot += T
            loss_blocks += torch.sum(loss_batch[..., :T])
        loss_blocks = loss_blocks / tot

        if len(loss_blocks) == 1:
            loss = loss_blocks.squeeze()
        else:
            loss = loss_blocks @ self.loss_weight
        return loss

    def calc_loss_smooth2(self, _x: Tensor, _y: Tensor, T_ys: Sequence[int], kern: int , stride: int ,pad: int = 0) -> Tensor:
        """
        _x: B, C, F, T
        _y: B, C, F, T
        """

        crit = self.criterion 

        x = F.max_pool2d(_x, (kern, 1), stride = stride ) 
        y = F.max_pool2d(_y, (kern, 1), stride = stride ) 

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            loss_no_red = crit(x, y)

        loss_blocks = torch.zeros(x.shape[1], device=y.device)

        tot =0 
        for T, loss_batch in zip(T_ys, loss_no_red):
            tot += T
            loss_blocks += torch.sum(loss_batch[..., :T])
        loss_blocks = loss_blocks / tot

        if len(loss_blocks) == 1:
            loss1 = loss_blocks.squeeze()
        else:
            loss1 = loss_blocks @ self.loss_weight

        x = F.max_pool2d(-1*_x, (kern, 1), stride = stride ) 
        y = F.max_pool2d(-1*_y, (kern, 1), stride = stride ) 

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            loss_no_red = crit(x, y)

        loss_blocks = torch.zeros(x.shape[1], device=y.device)

        tot =0 
        for T, loss_batch in zip(T_ys, loss_no_red):
            tot += T
            loss_blocks += torch.sum(loss_batch[..., :T])
        loss_blocks = loss_blocks / tot

        if len(loss_blocks) == 1:
            loss2 = loss_blocks.squeeze()
        else:
            loss2 = loss_blocks @ self.loss_weight

        loss = loss1 + loss2
        return loss





    @torch.no_grad()
    def should_stop(self, loss_valid, epoch):
        if epoch == self.max_epochs - 1:
            return True
        self.scheduler.step(loss_valid)
        # if self.scheduler.t_epoch == 0:  # if it is restarted now
        #     # if self.loss_last_restart < loss_valid:
        #     #     return True
        #     if self.loss_last_restart * hp.threshold_stop < loss_valid:
        #         self.max_epochs = epoch + self.scheduler.restart_period + 1
        #     self.loss_last_restart = loss_valid




    def train(self, loader_train: DataLoader, loader_valid: DataLoader,
              logdir: Path, first_epoch=0):

        os.makedirs(Path(logdir), exist_ok=True)
        self.writer = CustomWriter(str(logdir), group='train', purge_step=first_epoch)

        # Start Training
        step = 0

        loss_valid = self.validate(loader_valid, logdir, 0)
        l2_factor = hp.l2_factor
        
        num_filters = len(self.filters)

        for epoch in range(first_epoch, hp.n_epochs):
            self.writer.add_scalar('meta/lr', self.optimizer.param_groups[0]['lr'], epoch)
            pbar = tqdm(loader_train,
                        desc=f'epoch {epoch:3d}', postfix='[]', dynamic_ncols=True)
            avg_loss1 = AverageMeter(float)
            avg_loss2 = AverageMeter(float)

            avg_loss_tot = AverageMeter(float)
            avg_losses = [AverageMeter(float) for _ in range(num_filters) ]
            losses = [None] * num_filters

            avg_grad_norm = AverageMeter(float)

            for i_iter, data in enumerate(pbar):
                # get data
                ##import pdb; pdb.set_trace()
                y = self.preprocess(data)

                x_mel = self.model.spec_to_mel(y) 

                T_ys = data['T_ys']
                # forward
                x = self.model(x_mel) 

                y_mel = self.model.spec_to_mel(x)  


                step = step + 1

                loss1 = self.calc_loss(x    , y    , T_ys, self.criterion)
                loss2 = self.calc_loss(x_mel, y_mel, T_ys, self.criterion2)
                loss = loss1+ l2_factor*loss2

                # for i,f in enumerate(self.filters):
                #     s = self.f_specs[i][1]
                #     losses[i] = self.calc_loss_smooth(x,y,T_ys,f, s )
                #     loss = loss + losses[i]

                for i,(k,s) in enumerate(self.f_specs):
                    losses[i] = self.calc_loss_smooth2(x,y,T_ys,k, s )
                    loss = loss + losses[i]
            
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                           hp.thr_clip_grad)

                self.optimizer.step()

                # print
                avg_loss1.update(loss1.item(), len(T_ys))
                avg_loss2.update(loss2.item(), len(T_ys))
                avg_loss_tot.update(loss.item(), len(T_ys))

                for j,l in enumerate(losses):
                    avg_losses[j].update(l.item(), len(T_ys))

                
                pbar.set_postfix_str(f'{avg_loss1.get_average():.1e}')
                avg_grad_norm.update(grad_norm)

                if i_iter % 25 == 0:
                    self.writer.add_scalar('loss/loss1_train', avg_loss1.get_average(), epoch*len(loader_train)+ i_iter)
                    self.writer.add_scalar('loss/loss2_train', avg_loss2.get_average(), epoch*len(loader_train)+ i_iter)

                    for j, avg_loss in enumerate(avg_losses):
                        k = self.f_specs[j][0]
                        s = self.f_specs[j][1]
                        self.writer.add_scalar(f'loss/losses_{k}_{s}_train', avg_loss.get_average(), epoch*len(loader_train)+ i_iter)
                    self.writer.add_scalar('loss/loss_total_train', avg_loss_tot.get_average(), epoch*len(loader_train)+ i_iter)

                    self.writer.add_scalar('loss/grad', avg_grad_norm.get_average(), epoch*len(loader_train) +  i_iter)

                    avg_loss1 = AverageMeter(float)
                    avg_loss2 = AverageMeter(float)
                    avg_loss_tot = AverageMeter(float)
                    avg_losses = [AverageMeter(float) for _ in range(num_filters) ]
                    avg_grad_norm = AverageMeter(float)


            # Validation
            # loss_valid = self.validate(loader_valid, logdir, epoch)
            loss_valid = self.validate(loader_valid, logdir, epoch+1)

            # save loss & model
            if epoch % hp.period_save_state == hp.period_save_state - 1:
                torch.save(
                    (self.module.state_dict(),
                     self.optimizer.state_dict(),
                     self.scheduler.state_dict(),
                     ),
                    logdir / f'{epoch+1}.pt'
                )

            # Early stopping
            if self.should_stop(loss_valid, epoch):
                break

        self.writer.close()

    def audio_from_mag_spec(self, mag_spec):
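        # LWS estimates the phase for the magnitude spectrogram, an inverse STFT is
        # taken, and the result is reshaped to (n_samples, 1, 1).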
        mag_spec = mag_spec.astype(np.float64)
        spec_lws = self.lws_processor.run_lws(np.transpose(mag_spec))
        magspec_inv = self.lws_processor.istft(spec_lws)[:, np.newaxis, np.newaxis]
        magspec_inv = magspec_inv.astype('float32')
        return magspec_inv


    @torch.no_grad()
    def validate(self, loader: DataLoader, logdir: Path, epoch: int):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch: epoch index used as the summary-writer step.
        """
        self.model.eval()

        num_filters = len(self.filters)
        avg_stoi = AverageMeter(float)
        avg_stoi_norm = AverageMeter(float)
        avg_stoi_base = AverageMeter(float)


        avg_loss1 = AverageMeter(float)
        avg_lozz1 = AverageMeter(float)
        avg_loss2 = AverageMeter(float)
        avg_lozz2 = AverageMeter(float)

        avg_loss_tot = AverageMeter(float)
        avg_losses = [AverageMeter(float) for _ in range(num_filters) ]
        losses = [None] * num_filters

        pbar = tqdm(loader, desc='validate ', postfix='[0]', dynamic_ncols=True)
        num_iters = len(pbar)

        for i_iter, data in enumerate(pbar):

            ##import pdb; pdb.set_trace()
            y = self.preprocess(data)  # B, C, F, T
            x_mel = self.model.spec_to_mel(y) 

            z = self.model.mel_pseudo_inverse(x_mel)

            T_ys = data['T_ys']

            x = self.model(x_mel)  # B, C, F, T
            y_mel = self.model.spec_to_mel(x)     
            z_mel = self.model.spec_to_mel(y)

            loss1 = self.calc_loss(x, y, T_ys, self.criterion)
            lozz1 = self.calc_loss(z, y, T_ys, self.criterion)

            loss2 = self.calc_loss(x_mel, y_mel, T_ys, self.criterion2)
            lozz2 = self.calc_loss(z_mel, x_mel, T_ys, self.criterion2)

            loss = loss1 + loss2*hp.l2_factor

            # for i,f in enumerate(self.filters):
            #     s = self.f_specs[i][1]
            #     losses[i] = self.calc_loss_smooth(x,y,T_ys,f, s )
            #     loss = loss + losses[i]

            for i,(k,s) in enumerate(self.f_specs):
                losses[i] = self.calc_loss_smooth2(x,y,T_ys,k, s )
                loss = loss + losses[i]

            avg_loss1.update(loss1.item(), len(T_ys))
            avg_lozz1.update(lozz1.item(), len(T_ys))
            avg_loss2.update(loss2.item(), len(T_ys))
            avg_lozz2.update(lozz2.item(), len(T_ys))
            avg_loss_tot.update(loss.item(), len(T_ys))

            for j,l in enumerate(losses):
                avg_losses[j].update(l.item(), len(T_ys))

            # print
            pbar.set_postfix_str(f'{avg_loss1.get_average():.1e}')

            ## STOI evaluation with LWS
            for p in range(min(hp.num_stoi// num_iters,len(T_ys))):

                _x = x[p,0,:,:T_ys[p]].cpu()
                _y = y[p,0,:,:T_ys[p]].cpu()
                _z = z[p,0,:,:T_ys[p]].cpu()

                audio_x = self.audio_from_mag_spec(np.abs(_x.numpy()))
                y_wav = data['wav'][p]

                stoi_score= self.calc_stoi(y_wav, audio_x)
                avg_stoi.update(stoi_score)

                if i_iter not in self.prev_stoi_scores:
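                    # Cache the STOI of the LWS reconstruction of the real spectrogram
                    # (reference) and of the mel pseudo-inverse (baseline); keyed by
                    # batch index only, so later epochs (and later samples within a
                    # batch) reuse the first sample's scores.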
                    audio_y = self.audio_from_mag_spec(_y.numpy())
                    audio_z = self.audio_from_mag_spec(_z.numpy())

                    self.prev_stoi_scores[i_iter] = self.calc_stoi(y_wav, audio_y)
                    self.base_stoi_scores[i_iter] = self.calc_stoi(y_wav, audio_z)

                avg_stoi_norm.update( stoi_score / self.prev_stoi_scores[i_iter])
                avg_stoi_base.update( stoi_score / self.base_stoi_scores[i_iter])

            # write summary
            ## if i_iter < 4:
            if False:  # summary writing disabled; STOI alone is sufficient until the test stage
                x = x[0,0,:,:T_ys[0]].cpu()
                y = y[0,0,:,:T_ys[0]].cpu()
                z = z[0,0,:,:T_ys[0]].cpu()

                ##import pdb; pdb.set_trace()

                if i_iter == 3 and hp.request_drawings:
                    ymin = y[y > 0].min()
                    vmin, vmax = librosa.amplitude_to_db(np.array((ymin, y.max())))
                    kwargs_fig = dict(vmin=vmin, vmax=vmax)
                    fig_x = draw_spectrogram(x, **kwargs_fig)


                    ##self.add_figure(f'{self.group}Audio{idx}/0_Noisy_Spectrum', fig_x, step)
                    self.writer.add_figure(f'Audio{i_iter}/1_DNN_Output', fig_x, epoch)

                    if epoch ==0:
                        fig_y = draw_spectrogram(y, **kwargs_fig)
                        fig_z = draw_spectrogram(z, **kwargs_fig)
                        self.writer.add_figure(f'Audio{i_iter}/0_Pseudo_Inverse', fig_z, epoch)
                        self.writer.add_figure(f'Audio{i_iter}/2_Real_Spectrogram', fig_y, epoch)

                else:
                    audio_x = self.audio_from_mag_spec(np.abs(x.numpy()))


                    x_scale = np.abs(audio_x).max() / 0.5



                    self.writer.add_audio(f'Audio{i_iter}/1_DNN_Output',
                                torch.from_numpy(audio_x / x_scale),
                                epoch,
                                sample_rate=hp.sampling_rate)
                    if epoch ==0:

                        audio_y = self.audio_from_mag_spec(y.numpy())
                        audio_z = self.audio_from_mag_spec(z.numpy())
                        
                        z_scale = np.abs(audio_z).max() / 0.5
                        y_scale = np.abs(audio_y).max() / 0.5

                        self.writer.add_audio(f'Audio{i_iter}/0_Pseudo_Inverse',
                                    torch.from_numpy(audio_z / z_scale),
                                    epoch,
                                    sample_rate=hp.sampling_rate)


                        self.writer.add_audio(f'Audio{i_iter}/2_Real_Spectrogram',
                                    torch.from_numpy(audio_y / y_scale),
                                    epoch,
                                    sample_rate=hp.sampling_rate)


        self.writer.add_scalar(f'valid/loss', avg_loss1.get_average(), epoch)
        self.writer.add_scalar(f'valid/baseline', avg_lozz1.get_average(), epoch)
        self.writer.add_scalar(f'valid/melinv_loss', avg_loss2.get_average(), epoch)
        self.writer.add_scalar(f'valid/melinv_baseline', avg_lozz2.get_average(), epoch)
        self.writer.add_scalar(f'valid/STOI', avg_stoi.get_average(), epoch )
        self.writer.add_scalar(f'valid/STOI_normalized', avg_stoi_norm.get_average(), epoch )
        self.writer.add_scalar(f'valid/STOI_improvement', avg_stoi_base.get_average(), epoch )

        for j, avg_loss in enumerate(avg_losses):
            k = self.f_specs[j][0]
            s = self.f_specs[j][1]
            self.writer.add_scalar(f'valid/losses_{k}_{s}', avg_loss.get_average(), epoch)
        self.writer.add_scalar('valid/loss_total', avg_loss_tot.get_average(), epoch)

        self.model.train()

        return avg_loss1.get_average()

    def calc_stoi(self, y_wav, audio):
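        # Trim both waveforms to the shorter length and return the STOI score from
        # the shared evaluation module.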

        audio_len = min(y_wav.shape[0], audio.shape[0]  )
        measure = calc_using_eval_module(y_wav[:audio_len], audio[:audio_len,0,0])
        return measure['STOI']

    @torch.no_grad()
    def test(self, loader: DataLoader, logdir: Path):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        """
        self.model.eval()

        os.makedirs(Path(logdir), exist_ok=True)
        self.writer = CustomWriter(str(logdir), group='test')

        ##import pdb; pdb.set_trace()
        num_filters = len(self.filters)

        avg_loss1 = AverageMeter(float)
        avg_lozz1 = AverageMeter(float)
        avg_loss2 = AverageMeter(float)
        avg_lozz2 = AverageMeter(float)

        avg_loss_tot = AverageMeter(float)
        avg_losses = [AverageMeter(float) for _ in range(num_filters) ]
        losses = [None] * num_filters

        cnt = 0
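        # `cnt` is a running per-sample index; it is used as the global step
        # for the figures, audio clips and STOI scores logged below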
        for i_iter, data in enumerate(loader):

            # y: target spectrogram batch, x_mel: its mel projection,
            # z: pseudo-inverse reconstruction from the mel, x: network output
            y = self.preprocess(data)  # B, C, F, T
            x_mel = self.model.spec_to_mel(y)

            z = self.model.mel_pseudo_inverse(x_mel)

            T_ys = data['T_ys']
            x = self.model(x_mel)  # B, C, F, T
            y_mel = self.model.spec_to_mel(x)
            z_mel = self.model.spec_to_mel(y)

            loss1 = self.calc_loss(x, y, T_ys, self.criterion)
            lozz1 = self.calc_loss(z, y, T_ys, self.criterion)

            loss2 = self.calc_loss(x_mel, y_mel, T_ys, self.criterion2)
            lozz2 = self.calc_loss(z_mel, x_mel, T_ys, self.criterion2)

            loss = loss1 + loss2*hp.l2_factor

            # for i,f in enumerate(self.filters):
            #     s = self.f_specs[i][1]
            #     losses[i] = self.calc_loss_smooth(x,y,T_ys,f, s )
            #     loss = loss + losses[i]

            for i,(k,s) in enumerate(self.f_specs):
                losses[i] = self.calc_loss_smooth2(x,y,T_ys,k, s )
                loss = loss + losses[i]

            avg_loss1.update(loss1.item(), len(T_ys))
            avg_lozz1.update(lozz1.item(), len(T_ys))
            avg_loss2.update(loss2.item(), len(T_ys))
            avg_lozz2.update(lozz2.item(), len(T_ys))
            avg_loss_tot.update(loss.item(), len(T_ys))

            for j,l in enumerate(losses):
                avg_losses[j].update(l.item(), len(T_ys))


            # write summary

            pbar = tqdm(range(len(T_ys)), desc='test_batch', postfix='[0]', dynamic_ncols=True)

            for p in pbar:
                _x = x[p,0,:,:T_ys[p]].cpu()
                _y = y[p,0,:,:T_ys[p]].cpu()
                _z = z[p,0,:,:T_ys[p]].cpu()
                y_wav = data['wav'][p]

                ymin = _y[_y > 0].min()
                vmin, vmax = librosa.amplitude_to_db(np.array((ymin, _y.max())))
                kwargs_fig = dict(vmin=vmin, vmax=vmax)


                if hp.request_drawings:
                    fig_x = draw_spectrogram(_x, **kwargs_fig)
                    self.writer.add_figure(f'Audio/1_DNN_Output', fig_x, cnt)
                    fig_y = draw_spectrogram(_y, **kwargs_fig)
                    fig_z = draw_spectrogram(_z, **kwargs_fig)
                    self.writer.add_figure(f'Audio/0_Pseudo_Inverse', fig_z, cnt)
                    self.writer.add_figure(f'Audio/2_Real_Spectrogram', fig_y, cnt)

                audio_x = self.audio_from_mag_spec(np.abs(_x.numpy()))
                x_scale = np.abs(audio_x).max() / 0.5

                self.writer.add_audio(f'LWS/1_DNN_Output',
                            torch.from_numpy(audio_x / x_scale),
                            cnt,
                            sample_rate=hp.sampling_rate)

                audio_y = self.audio_from_mag_spec(_y.numpy())
                audio_z = self.audio_from_mag_spec(_z.numpy())
                
                z_scale = np.abs(audio_z).max() / 0.5
                y_scale = np.abs(audio_y).max() / 0.5

                self.writer.add_audio(f'LWS/0_Pseudo_Inverse',
                            torch.from_numpy(audio_z / z_scale),
                            cnt,
                            sample_rate=hp.sampling_rate)


                self.writer.add_audio(f'LWS/2_Real_Spectrogram',
                            torch.from_numpy(audio_y / y_scale),
                            cnt,
                            sample_rate=hp.sampling_rate)


                stoi_scores = {'0_Pseudo_Inverse'       : self.calc_stoi(y_wav, audio_z),
                               '1_DNN_Output'           : self.calc_stoi(y_wav, audio_x),
                               '2_Real_Spectrogram'     : self.calc_stoi(y_wav, audio_y)}

                self.writer.add_scalars(f'LWS/STOI', stoi_scores, cnt )
                # self.writer.add_scalar(f'STOI/0_Pseudo_Inverse_LWS', self.calc_stoi(y_wav, audio_z) , cnt)
                # self.writer.add_scalar(f'STOI/1_DNN_Output_LWS', self.calc_stoi(y_wav, audio_x) , cnt)
                # self.writer.add_scalar(f'STOI/2_Real_Spectrogram_LWS', self.calc_stoi(y_wav, audio_y) , cnt)
                cnt = cnt + 1

        # self.writer.add_scalar(f'valid/loss', avg_loss1.get_average(), epoch)
        # self.writer.add_scalar(f'valid/baseline', avg_lozz1.get_average(), epoch)
        # self.writer.add_scalar(f'valid/melinv_loss', avg_loss2.get_average(), epoch)
        # self.writer.add_scalar(f'valid/melinv_baseline', avg_lozz2.get_average(), epoch)

        # for j, avg_loss in enumerate(avg_losses):
        #     k = self.f_specs[j][0]
        #     s = self.f_specs[j][1]
        #     self.writer.add_scalar(f'valid/losses_{k}_{s}', avg_loss.get_average(), epoch)
        # self.writer.add_scalar('valid/loss_total', avg_loss_tot.get_average(), epoch)

        self.model.train()

        return 

    @torch.no_grad()
    def inspect(self, loader: DataLoader, logdir: Path):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch:
        """
        self.model.eval()

        os.makedirs(Path(logdir), exist_ok=True)
        self.writer = CustomWriter(str(logdir), group='test')

        num_filters = len(self.filters)

        avg_loss1 = AverageMeter(float)
        avg_lozz1 = AverageMeter(float)
        avg_loss2 = AverageMeter(float)
        avg_lozz2 = AverageMeter(float)

        avg_loss_tot = AverageMeter(float)
        avg_losses = [AverageMeter(float) for _ in range(num_filters) ]
        avg_losses_base = [AverageMeter(float) for _ in range(num_filters) ]
        losses = [None] * num_filters
        losses_base = [None] * num_filters
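        # the `*_base` meters track each calc_loss_smooth2 term evaluated with
        # the clean target compared against itself; they are used at the end
        # to normalize the per-filter breakdown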

        cnt = 0

        pbar = tqdm(enumerate(loader), desc='loss inspection', dynamic_ncols=True)

        for i_iter, data in pbar:

            y = self.preprocess(data)  # B, C, F, T
            x_mel = self.model.spec_to_mel(y) 

            z = self.model.mel_pseudo_inverse(x_mel)

            T_ys = data['T_ys']
            x = self.model(x_mel)  # B, C, F, T
            y_mel = self.model.spec_to_mel(x)     
            z_mel = self.model.spec_to_mel(y)

            loss1 = self.calc_loss(x, y, T_ys, self.criterion)
            lozz1 = self.calc_loss(z, y, T_ys, self.criterion)

            loss2 = self.calc_loss(x_mel, y_mel, T_ys, self.criterion2)
            lozz2 = self.calc_loss(z_mel, x_mel, T_ys, self.criterion2)

            loss = loss1 + loss2*hp.l2_factor

            # for i,f in enumerate(self.filters):
            #     s = self.f_specs[i][1]
            #     losses[i] = self.calc_loss_smooth(x,y,T_ys,f, s )
            #     loss = loss + losses[i]

            for i,(k,s) in enumerate(self.f_specs):
                losses[i] = self.calc_loss_smooth2(x,y,T_ys,k, s )
                losses_base[i] = self.calc_loss_smooth2(y,y,T_ys,k, s )

                loss = loss + losses[i]
            avg_loss1.update(loss1.item(), len(T_ys))
            avg_lozz1.update(lozz1.item(), len(T_ys))
            avg_loss2.update(loss2.item(), len(T_ys))
            avg_lozz2.update(lozz2.item(), len(T_ys))
            avg_loss_tot.update(loss.item(), len(T_ys))

            for j,l in enumerate(losses):
                avg_losses[j].update(l.item(), len(T_ys))
                
            for j,l in enumerate(losses_base):
                avg_losses_base[j].update(l.item(), len(T_ys))                

            # write summary

            if False:  # per-sample figure/audio/STOI logging disabled here; kept for reference
                for p in range(len(T_ys)):
                    _x = x[p,0,:,:T_ys[p]].cpu()
                    _y = y[p,0,:,:T_ys[p]].cpu()
                    _z = z[p,0,:,:T_ys[p]].cpu()
                    y_wav = data['wav'][p]

                    ymin = _y[_y > 0].min()
                    vmin, vmax = librosa.amplitude_to_db(np.array((ymin, _y.max())))
                    kwargs_fig = dict(vmin=vmin, vmax=vmax)


                    if hp.request_drawings:
                        fig_x = draw_spectrogram(_x, **kwargs_fig)
                        self.writer.add_figure(f'Audio/1_DNN_Output', fig_x, cnt)
                        fig_y = draw_spectrogram(_y, **kwargs_fig)
                        fig_z = draw_spectrogram(_z, **kwargs_fig)
                        self.writer.add_figure(f'Audio/0_Pseudo_Inverse', fig_z, cnt)
                        self.writer.add_figure(f'Audio/2_Real_Spectrogram', fig_y, cnt)

                    audio_x = self.audio_from_mag_spec(np.abs(_x.numpy()))
                    x_scale = np.abs(audio_x).max() / 0.5

                    self.writer.add_audio(f'LWS/1_DNN_Output',
                                torch.from_numpy(audio_x / x_scale),
                                cnt,
                                sample_rate=hp.sampling_rate)

                    audio_y = self.audio_from_mag_spec(_y.numpy())
                    audio_z = self.audio_from_mag_spec(_z.numpy())
                    
                    z_scale = np.abs(audio_z).max() / 0.5
                    y_scale = np.abs(audio_y).max() / 0.5

                    self.writer.add_audio(f'LWS/0_Pseudo_Inverse',
                                torch.from_numpy(audio_z / z_scale),
                                cnt,
                                sample_rate=hp.sampling_rate)


                    self.writer.add_audio(f'LWS/2_Real_Spectrogram',
                                torch.from_numpy(audio_y / y_scale),
                                cnt,
                                sample_rate=hp.sampling_rate)


                    stoi_scores = {'0_Pseudo_Inverse'       : self.calc_stoi(y_wav, audio_z),
                                '1_DNN_Output'           : self.calc_stoi(y_wav, audio_x),
                                '2_Real_Spectrogram'     : self.calc_stoi(y_wav, audio_y)}

                    self.writer.add_scalars(f'LWS/STOI', stoi_scores, cnt )
                    # self.writer.add_scalar(f'STOI/0_Pseudo_Inverse_LWS', self.calc_stoi(y_wav, audio_z) , cnt)
                    # self.writer.add_scalar(f'STOI/1_DNN_Output_LWS', self.calc_stoi(y_wav, audio_x) , cnt)
                    # self.writer.add_scalar(f'STOI/2_Real_Spectrogram_LWS', self.calc_stoi(y_wav, audio_y) , cnt)
                    cnt = cnt + 1

        for j, avg_loss in enumerate(avg_losses):
            self.writer.add_scalar('inspect/losses_breakdown', avg_loss.get_average(), j)

        for j, avg_loss in enumerate(avg_losses_base):
            self.writer.add_scalar('inspect/losses_base_breakdown', avg_loss.get_average(), j)

        # per-filter ratio of the self-baseline loss to the model loss
        for j, avg_loss in enumerate(avg_losses):
            avg_loss_base = avg_losses_base[j]
            self.writer.add_scalar('inspect/losses_normalized_breakdown',
                                   avg_loss_base.get_average() / avg_loss.get_average(), j)


        # self.writer.add_scalar(f'valid/loss', avg_loss1.get_average(), epoch)
        # self.writer.add_scalar(f'valid/baseline', avg_lozz1.get_average(), epoch)
        # self.writer.add_scalar(f'valid/melinv_loss', avg_loss2.get_average(), epoch)
        # self.writer.add_scalar(f'valid/melinv_baseline', avg_lozz2.get_average(), epoch)

        # for j, avg_loss in enumerate(avg_losses):
        #     k = self.f_specs[j][0]
        #     s = self.f_specs[j][1]
        #     self.writer.add_scalar(f'valid/losses_{k}_{s}', avg_loss.get_average(), epoch)
        # self.writer.add_scalar('valid/loss_total', avg_loss_tot.get_average(), epoch)

        self.model.train()

        return 


    @torch.no_grad()
    def infer(self, loader: DataLoader, logdir: Path):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch:
        """
        def save_feature(num_snr, i_speech: int, s_path_speech: str, speech: ndarray, mag_mel2spec) -> tuple:
            spec_clean = np.ascontiguousarray(librosa.stft(speech, **hp.kwargs_stft))
            mag_clean = np.ascontiguousarray(np.abs(spec_clean)[..., np.newaxis])
            

            signal_power = np.mean(np.abs(speech)**2)
            list_dict = []
            list_snr_db = []
            for _ in range(num_snr):
                snr_db = -6*np.random.rand()
                list_snr_db.append(snr_db)
                snr = librosa.db_to_power(snr_db)
                noise_power = signal_power / snr
                noisy = speech + np.sqrt(noise_power) * np.random.randn(len(speech))
                spec_noisy = librosa.stft(noisy, **hp.kwargs_stft)
                spec_noisy = np.ascontiguousarray(spec_noisy)

                list_dict.append(
                    dict(spec_noisy=spec_noisy,
                        speech=speech,
                        spec_clean=spec_clean,
                        mag_clean=mag_mel2spec,
                        path_speech=s_path_speech,
                        length=len(speech),
                        )
                )
            return list_snr_db, list_dict
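
        # `save_feature` above packages `num_snr` noisy copies of each
        # utterance: `snr_db` is drawn uniformly in (-6, 0] dB, converted to a
        # power ratio snr = 10 ** (snr_db / 10), and white noise is scaled so
        # that signal_power / noise_power == snr.  For example, snr_db = -6
        # gives snr ≈ 0.251, i.e. noise_power ≈ 3.98 * signal_power.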


        self.model.eval()

        os.makedirs(Path(logdir), exist_ok=True)

        cnt = 0

        pbar = tqdm(loader, desc='mel2inference', postfix='[0]', dynamic_ncols=True)

        form = '{:05d}_mel2spec_{:+.2f}dB.npz'
        num_snr = hp.num_snr
        for i_iter, data in enumerate(pbar):

            y = self.preprocess(data)  # B, C, F, T
            x_mel = self.model.spec_to_mel(y) 

            T_ys = data['T_ys']
            x = self.model(x_mel)  # B, C, F, T

            for p in range(len(T_ys)):
                _x = x[p,0,:,:T_ys[p]].unsqueeze(2).cpu().numpy()
                speech = data['wav'][p].numpy()

                list_snr_db, list_dict = save_feature(num_snr, cnt, data['path_speech'][p] , speech, _x)
                cnt = cnt + 1
                for snr_db, dict_result in zip(list_snr_db, list_dict):
                    np.savez(logdir / form.format(cnt, snr_db),
                            **dict_result,
                            )
        self.model.train()

        return 
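        # the saved archives can be read back with numpy, e.g. (illustrative
        # file name; real names follow `form` above with the running counter
        # and the sampled SNR):
        #
        #     d = np.load(logdir / '00001_mel2spec_-1.23dB.npz')
        #     spec_noisy, mag_clean, length = d['spec_noisy'], d['mag_clean'], d['length']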
Example #13
0
    def train(self, loader_train: DataLoader, loader_valid: DataLoader,
              logdir: Path, first_epoch=0):

        os.makedirs(Path(logdir), exist_ok=True)
        self.writer = CustomWriter(str(logdir), group='train', purge_step=first_epoch)

        # Start Training
        step = 0
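        # a validation pass is run below before any training so that epoch 0
        # in the logs reflects the untrained (or restored) model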

        loss_valid = self.validate(loader_valid, logdir, 0)
        l2_factor = hp.l2_factor
        
        num_filters = len(self.filters)

        for epoch in range(first_epoch, hp.n_epochs):
            self.writer.add_scalar('meta/lr', self.optimizer.param_groups[0]['lr'], epoch)
            pbar = tqdm(loader_train,
                        desc=f'epoch {epoch:3d}', postfix='[]', dynamic_ncols=True)
            avg_loss1 = AverageMeter(float)
            avg_loss2 = AverageMeter(float)

            avg_loss_tot = AverageMeter(float)
            avg_losses = [AverageMeter(float) for _ in range(num_filters) ]
            losses = [None] * num_filters

            avg_grad_norm = AverageMeter(float)

            for i_iter, data in enumerate(pbar):
                # get data
                y = self.preprocess(data)

                x_mel = self.model.spec_to_mel(y) 

                T_ys = data['T_ys']
                # forward
                x = self.model(x_mel) 

                y_mel = self.model.spec_to_mel(x)  


                step = step + 1
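                # the total loss assembled below combines the main spectrogram
                # loss (loss1), the mel-domain consistency loss weighted by
                # hp.l2_factor (loss2), and one calc_loss_smooth2 term per
                # (k, s) entry of self.f_specs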

                loss1 = self.calc_loss(x, y, T_ys, self.criterion)
                loss2 = self.calc_loss(x_mel, y_mel, T_ys, self.criterion2)
                loss = loss1 + l2_factor * loss2

                # for i,f in enumerate(self.filters):
                #     s = self.f_specs[i][1]
                #     losses[i] = self.calc_loss_smooth(x,y,T_ys,f, s )
                #     loss = loss + losses[i]

                for i, (k, s) in enumerate(self.f_specs):
                    losses[i] = self.calc_loss_smooth2(x, y, T_ys, k, s)
                    loss = loss + losses[i]

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                           hp.thr_clip_grad)

                self.optimizer.step()

                # print
                avg_loss1.update(loss1.item(), len(T_ys))
                avg_loss2.update(loss2.item(), len(T_ys))
                avg_loss_tot.update(loss.item(), len(T_ys))

                for j,l in enumerate(losses):
                    avg_losses[j].update(l.item(), len(T_ys))

                
                pbar.set_postfix_str(f'{avg_loss1.get_average():.1e}')
                avg_grad_norm.update(grad_norm)

                if i_iter % 25 == 0:
                    self.writer.add_scalar('loss/loss1_train', avg_loss1.get_average(), epoch*len(loader_train)+ i_iter)
                    self.writer.add_scalar('loss/loss2_train', avg_loss2.get_average(), epoch*len(loader_train)+ i_iter)

                    for j, avg_loss in enumerate(avg_losses):
                        k = self.f_specs[j][0]
                        s = self.f_specs[j][1]
                        self.writer.add_scalar(f'loss/losses_{k}_{s}_train', avg_loss.get_average(), epoch*len(loader_train)+ i_iter)
                    self.writer.add_scalar('loss/loss_total_train', avg_loss_tot.get_average(), epoch*len(loader_train)+ i_iter)

                    self.writer.add_scalar('loss/grad', avg_grad_norm.get_average(), epoch*len(loader_train) +  i_iter)
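
                    # below, the running meters are reset so the next logged
                    # point averages only the iterations since this write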

                    avg_loss1 = AverageMeter(float)
                    avg_loss2 = AverageMeter(float)
                    avg_loss_tot = AverageMeter(float)
                    avg_losses = [AverageMeter(float) for _ in range(num_filters) ]
                    avg_grad_norm = AverageMeter(float)


            # Validation
            # loss_valid = self.validate(loader_valid, logdir, epoch)
            loss_valid = self.validate(loader_valid, logdir, epoch+1)

            # save loss & model
            if epoch % hp.period_save_state == hp.period_save_state - 1:
                torch.save(
                    (self.module.state_dict(),
                     self.optimizer.state_dict(),
                     self.scheduler.state_dict(),
                     ),
                    logdir / f'{epoch+1}.pt'
                )

            # Early stopping
            if self.should_stop(loss_valid, epoch):
                break

        self.writer.close()
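
    # A minimal, hypothetical driver for this class (the DataLoaders and the
    # `hp` configuration are assumed to be built elsewhere in the project):
    #
    #     trainer = Trainer()
    #     trainer.train(loader_train, loader_valid, hp.logdir)
    #     trainer.test(loader_test, hp.logdir / 'test')
    #
    # `loader_train`, `loader_valid` and `loader_test` are placeholder names
    # for torch DataLoader instances.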