Example #1
    def test(self, loader: DataLoader, logdir: Path):
        def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
            # forward hook; closes over group, i_iter, and module_counts,
            # which are (re)bound later in test()
            module_name = type(module).__name__  # e.g. 'Conv2d'
            dict_to_save = dict()
            # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
            dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

            i_module = module_counts[module_name]
            for i, o in enumerate(dict_to_save['out']):
                save_forward.writer.add_figure(
                    f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                    draw_spectrogram(o, to_db=False),
                    i,
                )
            scio.savemat(
                str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
                dict_to_save,
            )
            module_counts[module_name] += 1

        group = logdir.name.split('_')[0]

        self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()

        module_counts = None
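        # optionally hook every block so its outputs are dumped during the
        # first n_save_block_outs test iterations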
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            save_forward.writer = self.writer  # stash the writer on the hook function
            for sub in self.module.children():
                if isinstance(sub, nn.ModuleList):
                    for m in sub:
                        m.register_forward_hook(save_forward)
                elif isinstance(sub, nn.ModuleDict):
                    for m in sub.values():
                        m.register_forward_hook(save_forward)
                else:
                    sub.register_forward_hook(save_forward)

        pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        cnt_sample = 0
        for i_iter, data in enumerate(pbar):
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            if module_counts is not None:
                # fresh per-block hook counters for this iteration
                module_counts = defaultdict(int)

            # stop once the first n_save_block_outs iterations have been dumped
            # (the chained comparison only breaks when the option is enabled)
            if 0 < hp.n_save_block_outs == i_iter:
                break
            _, output, residual = self.model(x,
                                             mag,
                                             max_length,
                                             repeat=hp.repeat_test)

            # write summary
            for i_b in range(len(T_ys)):
                i_sample = cnt_sample + i_b
                one_sample = ComplexSpecDataset.decollate_padded(
                    data, i_b)  # F, T, C

                out_one = self.postprocess(output, residual, T_ys, i_b,
                                           loader.dataset)

                ComplexSpecDataset.save_dirspec(
                    logdir / hp.form_result.format(i_sample), **one_sample,
                    **out_one)

                measure = self.writer.write_one(i_sample,
                                                **out_one,
                                                **one_sample,
                                                suffix=f'_{hp.repeat_test}')
                if avg_measure is None:
                    avg_measure = AverageMeter(init_value=measure)
                else:
                    avg_measure.update(measure)
                # print
                # str_measure = arr2str(measure).replace('\n', '; ')
                # pbar.write(str_measure)
            cnt_sample += len(T_ys)

        self.model.train()

        avg_measure = avg_measure.get_average()

        self.writer.add_text(f'{group}/Average Measure/Proposed',
                             str(avg_measure[0]))
        self.writer.add_text(f'{group}/Average Measure/Reverberant',
                             str(avg_measure[1]))
        self.writer.close()  # Explicitly close

        print()
        str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
        print(f'Average: {str_avg_measure}')
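
The test loop above leans on PyTorch forward hooks to capture each block's output. A minimal, self-contained sketch of that mechanism (standard torch.nn API; the toy model and hook body are placeholders, not the code above):

    import torch
    import torch.nn as nn

    captured = {}

    def save_forward(module: nn.Module, in_: tuple, out: torch.Tensor):
        # key each intermediate output by the module's class name
        captured[type(module).__name__] = out.detach().cpu()

    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    handles = [m.register_forward_hook(save_forward) for m in model.children()]

    with torch.no_grad():
        model(torch.randn(2, 8))

    # unlike the snippet above, remove the hooks once done so later
    # forward passes are not re-captured
    for h in handles:
        h.remove()

Note that test() never removes its hooks: if it ran twice on the same model, every block would be dumped twice per forward pass.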
Example #2
    def train(self,
              loader_train: DataLoader,
              loader_valid: DataLoader,
              logdir: Path,
              first_epoch=0):
        self.writer = CustomWriter(str(logdir),
                                   group='train',
                                   purge_step=first_epoch)
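        # purge_step (as in SummaryWriter): events a previous run logged at
        # step >= first_epoch are dropped, so resumed runs leave no duplicates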

        # Start Training
        for epoch in range(first_epoch, hp.n_epochs):
            self.writer.add_scalar('loss/lr',
                                   self.optimizer.param_groups[0]['lr'], epoch)
            print()
            pbar = tqdm(loader_train,
                        desc=f'epoch {epoch:3d}',
                        postfix='[]',
                        dynamic_ncols=True)
            avg_loss = AverageMeter(float)
            avg_grad_norm = AverageMeter(float)

            for i_iter, data in enumerate(pbar):
                # get data
                x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
                T_ys = data['T_ys']

                # forward
                output_loss, _, _ = self.model(
                    x, mag, max_length, repeat=hp.repeat_train)  # B, C, F, T

                loss = self.calc_loss(output_loss, y, T_ys)

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), hp.thr_clip_grad)

                self.optimizer.step()

                # print
                avg_loss.update(loss.item(), len(T_ys))
                pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')
                avg_grad_norm.update(float(grad_norm))  # clip_grad_norm_ may return a 0-dim tensor

            self.writer.add_scalar('loss/train', avg_loss.get_average(), epoch)
            self.writer.add_scalar('loss/grad', avg_grad_norm.get_average(),
                                   epoch)

            # Validation
            loss_valid = self.validate(loader_valid,
                                       logdir,
                                       epoch,
                                       repeat=hp.repeat_train)

            # save model/optimizer/scheduler state every period_save_state epochs
            if epoch % hp.period_save_state == hp.period_save_state - 1:
                torch.save((
                    self.module.state_dict(),
                    self.optimizer.state_dict(),
                    self.scheduler.state_dict(),
                ), logdir / f'{epoch}.pt')

            # Early stopping
            if self.should_stop(loss_valid, epoch):
                break

        self.writer.close()
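
AverageMeter itself is not shown in any of these snippets. From its call sites (AverageMeter(float), AverageMeter(init_value=measure), update(value, n=1), get_average(), reset()) it is a running weighted mean; a plausible minimal implementation under exactly those assumptions:

    class AverageMeter:
        # interface inferred from the call sites above; works for floats
        # and for numpy arrays (the `measure` case in test())
        def __init__(self, dtype=float, init_value=None):
            self._dtype = dtype  # kept only to mirror the observed signature
            self.reset()
            if init_value is not None:
                self.update(init_value)

        def reset(self):
            self._sum = None
            self._count = 0

        def update(self, value, n: int = 1):
            weighted = value * n
            self._sum = weighted if self._sum is None else self._sum + weighted
            self._count += n

        def get_average(self):
            return self._sum / self._count

reset() is the variant Example #4 uses to restart the running averages at each epoch boundary.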
Example #3
    def validate(self, loader: DataLoader, logdir: Path, epoch: int, repeat=1):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch:
        """
        suffix = f'_{repeat}' if repeat > 1 else ''

        self.model.eval()

        avg_loss = AverageMeter(float)

        pbar = tqdm(loader,
                    desc='validate ',
                    postfix='[0]',
                    dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output_loss, output, residual = self.model(x,
                                                       mag,
                                                       max_length,
                                                       repeat=repeat)

            # loss
            loss = self.calc_loss(output_loss, y, T_ys)
            avg_loss.update(loss.item(), len(T_ys))

            # print
            pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

            # write summary for the first batch only
            if i_iter == 0:
                if not self.valid_eval_sample:
                    # cache one decollated sample (F, T, C) and reuse it every epoch
                    self.valid_eval_sample = ComplexSpecDataset.decollate_padded(
                        data, 0)

                out_one = self.postprocess(output, residual, T_ys, 0,
                                           loader.dataset)

                # ComplexSpecDataset.save_dirspec(
                #     logdir / hp.form_result.format(epoch),
                #     **self.valid_eval_sample, **out_one
                # )

                if not self.writer.reused_sample:
                    one_sample = self.valid_eval_sample
                else:
                    one_sample = dict()

                self.writer.write_one(epoch,
                                      **one_sample,
                                      **out_one,
                                      suffix=suffix)

        self.writer.add_scalar(f'loss/valid{suffix}', avg_loss.get_average(),
                               epoch)

        self.model.train()

        return avg_loss.get_average()
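
One thing validate() omits: its forward passes still build autograd graphs, since only model.eval() is called. Wrapping inference in torch.no_grad() saves memory and time; a minimal sketch of the decorator form (placeholder names, not the method above):

    import torch
    import torch.nn as nn

    @torch.no_grad()  # no graph is recorded inside this function
    def run_validation(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
        model.eval()
        out = model(x)
        model.train()
        return out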
Example #4
        # one distance per feature level, earlier levels weighted more heavily
        feat_weights = (3.0, 2.5, 2.0, 1.5, 1.0)
        loss_feat = sum(w * criterion_cons(fake, real)
                        for w, fake, real in zip(feat_weights,
                                                 fake_features,
                                                 real_features))

        optimizer_G.zero_grad()
        loss_feat.backward()
        optimizer_G.step()

        # .item() detaches the scalar so the meter doesn't keep the graph alive
        running_loss_feat.update(loss_feat.item(), image.size(0))

        if global_step % steps_per_epoch == 0:
            epoch_loss_GD = running_loss_GD.get_average()
            epoch_loss_D = running_loss_D.get_average()
            epoch_loss_cons = running_loss_cons.get_average()
            epoch_loss_feat = running_loss_feat.get_average()

            running_loss_GD.reset()
            running_loss_D.reset()
            running_loss_cons.reset()
            running_loss_feat.reset()

            msg = "epoch- %d, loss_GD- %.4f, loss_cons- %.4f, loss_feat- %.4f, loss_D- %.4f" % (
                global_epoch, epoch_loss_GD, epoch_loss_cons, epoch_loss_feat,
                epoch_loss_D)

            logger.info(msg)
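
This last snippet reads as a GAN feature-matching loss: distances between paired discriminator feature maps for fake and real images, with the first levels weighted most heavily. Packaged as a reusable helper (the function name and defaults are illustrative, not from the source):

    import torch.nn as nn

    def feature_matching_loss(fake_feats, real_feats, weights,
                              criterion=nn.L1Loss()):
        # one weighted distance per feature level, summed
        assert len(weights) == len(fake_feats) == len(real_feats)
        return sum(w * criterion(f, r)
                   for w, f, r in zip(weights, fake_feats, real_feats))

    # usage mirroring the snippet above:
    # loss_feat = feature_matching_loss(fake_features, real_features,
    #                                   weights=(3.0, 2.5, 2.0, 1.5, 1.0))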