else:
    path_state_dict = None

# Training + Validation Set

# run
if args.infer and path_state_dict is None:
    trainer = None
    num_workers = 1
else:
    os.makedirs(logdir_train, exist_ok=True)
    trainer = Trainer(path_state_dict)
    num_workers = trainer.num_workers

if args.train:
    dataset_train = ComplexSpecDataset('train')
    dataset_valid = ComplexSpecDataset('valid')

    dataset_train.set_needs(**hp.channels)
    # dataset_valid.set_needs(**hp.channels)
    loader_train = DataLoader(
        dataset_train,
        batch_size=hp.batch_size,
        num_workers=num_workers,
        collate_fn=dataset_train.pad_collate,
        pin_memory=(hp.device != 'cpu'),
        shuffle=True,
        drop_last=True,
    )
    loader_valid = DataLoader(
        dataset_valid,
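
The listing above is cut off inside the loader_valid call. A sketch of how such a validation loader is typically completed, assuming it mirrors loader_train except that validation data is neither shuffled nor truncated (these arguments are an assumption, not the original code):

loader_valid = DataLoader(
    dataset_valid,
    batch_size=hp.batch_size,
    num_workers=num_workers,
    collate_fn=dataset_valid.pad_collate,
    pin_memory=(hp.device != 'cpu'),
    shuffle=False,    # assumed: validation data is not shuffled
    drop_last=False,  # assumed: keep every validation sample
)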
Example #2
    def test(self, loader: DataLoader, logdir: Path):
        def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
            module_name = str(module).split('(')[0]
            dict_to_save = dict()
            # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
            dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

            i_module = module_counts[module_name]
            for i, o in enumerate(dict_to_save['out']):
                save_forward.writer.add_figure(
                    f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                    draw_spectrogram(o, to_db=False),
                    i,
                )
            scio.savemat(
                str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
                dict_to_save,
            )
            module_counts[module_name] += 1

        group = logdir.name.split('_')[0]

        self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()

        module_counts = None
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            save_forward.writer = self.writer  # stash the writer on the hook function itself
            for sub in self.module.children():
                if isinstance(sub, nn.ModuleList):
                    for m in sub:
                        m.register_forward_hook(save_forward)
                elif isinstance(sub, nn.ModuleDict):
                    for m in sub.values():
                        m.register_forward_hook(save_forward)
                else:
                    sub.register_forward_hook(save_forward)

        pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        cnt_sample = 0
        for i_iter, data in enumerate(pbar):
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            if module_counts is not None:
                module_counts = defaultdict(int)

            # chained comparison: break once i_iter reaches n_save_block_outs
            # (only active when n_save_block_outs > 0)
            if 0 < hp.n_save_block_outs == i_iter:
                break
            _, output, residual = self.model(x,
                                             mag,
                                             max_length,
                                             repeat=hp.repeat_test)

            # write summary
            for i_b in range(len(T_ys)):
                i_sample = cnt_sample + i_b
                one_sample = ComplexSpecDataset.decollate_padded(
                    data, i_b)  # F, T, C

                out_one = self.postprocess(output, residual, T_ys, i_b,
                                           loader.dataset)

                ComplexSpecDataset.save_dirspec(
                    logdir / hp.form_result.format(i_sample), **one_sample,
                    **out_one)

                measure = self.writer.write_one(i_sample,
                                                **out_one,
                                                **one_sample,
                                                suffix=f'_{hp.repeat_test}')
                if avg_measure is None:
                    avg_measure = AverageMeter(init_value=measure)
                else:
                    avg_measure.update(measure)
                # print
                # str_measure = arr2str(measure).replace('\n', '; ')
                # pbar.write(str_measure)
            cnt_sample += len(T_ys)

        self.model.train()

        avg_measure = avg_measure.get_average()

        self.writer.add_text(f'{group}/Average Measure/Proposed',
                             str(avg_measure[0]))
        self.writer.add_text(f'{group}/Average Measure/Reverberant',
                             str(avg_measure[1]))
        self.writer.close()  # Explicitly close

        print()
        str_avg_measure = arr2str(avg_measure).replace('\n', '; ')
        print(f'Average: {str_avg_measure}')
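
The save_forward hook above relies on PyTorch's register_forward_hook, which calls the hook after each registered module's forward pass with (module, inputs, output). A minimal, self-contained sketch of the same pattern:

import torch
from torch import nn

def print_shape(module, in_, out):
    # `in_` is a tuple of the module's positional inputs; `out` is its output
    print(type(module).__name__, tuple(out.shape))

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
handles = [m.register_forward_hook(print_shape) for m in model.children()]
model(torch.randn(1, 4))
for h in handles:
    h.remove()  # hooks persist until removed

Note that the hooks registered in test() are never removed, so they would keep firing in any later forward passes on the same model.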
Example #3
    def validate(self, loader: DataLoader, logdir: Path, epoch: int, repeat=1):
        """ Evaluate the performance of the model.

        :param loader: DataLoader to use.
        :param logdir: path of the result files.
        :param epoch:
        """
        suffix = f'_{repeat}' if repeat > 1 else ''

        self.model.eval()

        avg_loss = AverageMeter(float)

        pbar = tqdm(loader,
                    desc='validate ',
                    postfix='[0]',
                    dynamic_ncols=True)
        for i_iter, data in enumerate(pbar):
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T
            T_ys = data['T_ys']

            # forward
            output_loss, output, residual = self.model(x,
                                                       mag,
                                                       max_length,
                                                       repeat=repeat)

            # loss
            loss = self.calc_loss(output_loss, y, T_ys)
            avg_loss.update(loss.item(), len(T_ys))

            # print
            pbar.set_postfix_str(f'{avg_loss.get_average():.1e}')

            # write summary
            if i_iter == 0:
                # F, T, C
                if not self.valid_eval_sample:
                    self.valid_eval_sample = ComplexSpecDataset.decollate_padded(
                        data, 0)

                out_one = self.postprocess(output, residual, T_ys, 0,
                                           loader.dataset)

                # ComplexSpecDataset.save_dirspec(
                #     logdir / hp.form_result.format(epoch),
                #     **self.valid_eval_sample, **out_one
                # )

                # send the full sample only until the writer has cached a
                # reusable copy of it
                if not self.writer.reused_sample:
                    one_sample = self.valid_eval_sample
                else:
                    one_sample = dict()

                self.writer.write_one(epoch,
                                      **one_sample,
                                      **out_one,
                                      suffix=suffix)

        self.writer.add_scalar(f'loss/valid{suffix}', avg_loss.get_average(),
                               epoch)

        self.model.train()

        return avg_loss.get_average()
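
AverageMeter is not shown in the listing; judging from the calls here (AverageMeter(float), AverageMeter(init_value=measure), update(value, n), get_average()), it tracks a running weighted mean. A minimal sketch under those assumptions; the real class may differ:

class AverageMeter:
    """Running (weighted) mean. Sketch only."""

    def __init__(self, init_factory=None, init_value=None):
        # init_factory (e.g. float) is accepted for compatibility with
        # AverageMeter(float); init_value seeds the mean with one sample.
        self.sum = init_value if init_value is not None else (
            init_factory() if init_factory is not None else 0.0)
        self.count = 0 if init_value is None else 1

    def update(self, value, n=1):
        self.sum = self.sum + value * n
        self.count += n

    def get_average(self):
        return self.sum / self.count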
Example #4

            os.makedirs(logdir_test)
        else:
            exit()
    os.makedirs(logdir_test, exist_ok=True)

# epoch, state dict
first_epoch = args.epoch + 1
if first_epoch > 0:
    path_state_dict = logdir_train / f'{args.epoch}.pt'
    if not path_state_dict.exists():
        raise FileNotFoundError(path_state_dict)
else:
    path_state_dict = None

# Training + Validation Set
dataset_temp = ComplexSpecDataset('train')
dataset_train, dataset_valid = ComplexSpecDataset.split(
    dataset_temp, (hp.train_ratio, -1))
dataset_train.set_needs(**hp.channels)
dataset_valid.set_needs(**hp.channels)

# run
trainer = Trainer(path_state_dict)
if args.train:
    loader_train = DataLoader(
        dataset_train,
        batch_size=hp.batch_size,
        num_workers=hp.num_workers,
        collate_fn=dataset_train.pad_collate,
        pin_memory=(hp.device != 'cpu'),
        shuffle=True,
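
In this variant the train/valid sets come from one dataset: the (hp.train_ratio, -1) argument reads as "train_ratio of the data for training, the remainder for validation". ComplexSpecDataset.split itself is not shown; a generic sketch of the same idea with torch.utils.data.random_split, assuming train_ratio is a fraction in (0, 1):

import torch
from torch.utils.data import random_split

def split_by_ratio(dataset, train_ratio, seed=0):
    n_train = int(len(dataset) * train_ratio)
    n_valid = len(dataset) - n_train  # the "-1" remainder
    return random_split(dataset, [n_train, n_valid],
                        generator=torch.Generator().manual_seed(seed))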
Example #5
    def test(self, loader: DataLoader, logdir: Path):
        def save_forward(module: nn.Module, in_: Tensor, out: Tensor):
            module_name = str(module).split('(')[0]
            dict_to_save = dict()
            # dict_to_save['in'] = in_.detach().cpu().numpy().squeeze()
            dict_to_save['out'] = out.detach().cpu().numpy().squeeze()

            i_module = module_counts[module_name]
            for i, o in enumerate(dict_to_save['out']):
                save_forward.writer.add_figure(
                    f'{group}/blockout_{i_iter}/{module_name}{i_module}',
                    draw_spectrogram(o, to_db=False),
                    i,
                )
            scio.savemat(
                str(logdir / f'blockout_{i_iter}_{module_name}{i_module}.mat'),
                dict_to_save,
            )
            module_counts[module_name] += 1

        group = logdir.name.split('_')[0]

        if self.writer is None:
            self.writer = CustomWriter(str(logdir), group=group)

        avg_measure = None
        self.model.eval()
        depth = hp.model['depth']

        module_counts = None
        if hp.n_save_block_outs:
            module_counts = defaultdict(int)
            save_forward.writer = self.writer  # stash the writer on the hook function itself
            for sub in self.module.children():
                if isinstance(sub, nn.ModuleList):
                    for m in sub:
                        m.register_forward_hook(save_forward)
                elif isinstance(sub, nn.ModuleDict):
                    for m in sub.values():
                        m.register_forward_hook(save_forward)
                else:
                    sub.register_forward_hook(save_forward)

        # pbar = tqdm(loader, desc=group, dynamic_ncols=True)
        cnt_sample = 0
        for i_iter, data in enumerate(loader):

            # per-batch-index cache of (reused_sample, result_eval_glim)
            sample_dict = {}
            # get data
            x, mag, max_length, y = self.preprocess(data)  # B, C, F, T

            if hp.noisy_init:
                # replace the network input with white Gaussian noise
                x = torch.normal(0, 1, x.shape).cuda(self.in_device)

            T_ys = data['T_ys']

            # forward
            if module_counts is not None:
                module_counts = defaultdict(int)

            # if 0 < hp.n_save_block_outs == i_iter:
            #     break
            repeats = 1

            # warm-up passes (outputs discarded) to stabilize timing
            for _ in range(3):
                _, output, residual = self.model(x,
                                                 mag,
                                                 max_length,
                                                 repeat=1,
                                                 train_step=1)
                _, output = self.model.plain_gla(x,
                                                 mag,
                                                 max_length,
                                                 repeat=repeats)

            while repeats <= hp.repeat_test:
                # fresh meters for this repeats setting, created before the
                # timer starts so they do not pollute the measurement
                avg_measure = AverageMeter()
                avg_measure2 = AverageMeter()

                stime = ms()
                _, output, residual = self.model(x,
                                                 mag,
                                                 max_length,
                                                 repeat=repeats,
                                                 train_step=1)
                etime = ms(stime)
                speed = (max_length / hp.fs) * len(T_ys) / (etime / 1000)
                ##print("degli: %d repeats, length: %d, time: %d miliseconds, ratio = %.02f" % (repeats, max_length , etime, speed))
                ##self.writer.add_scalar("Test Performance/degli", speed, repeats)
                # write summary
                for i_b in tqdm(range(len(T_ys)),
                                desc=f"degli, {repeats} repeats",
                                dynamic_ncols=True):
                    i_sample = cnt_sample + i_b

                    if i_b not in sample_dict:
                        one_sample = ComplexSpecDataset.decollate_padded(
                            data, i_b)
                        reused_sample, result_eval_glim = self.writer.write_zero(
                            0, i_b, **one_sample, suffix="Base stats")
                        sample_dict[i_b] = (reused_sample, result_eval_glim)

                    reused_sample, result_eval_glim = sample_dict[i_b]

                    out_one = self.postprocess(output, residual, T_ys, i_b,
                                               loader.dataset)

                    # ComplexSpecDataset.save_dirspec(
                    #     logdir / hp.form_result.format(i_sample),
                    #     **one_sample, **out_one
                    # )

                    measure = self.writer.write_one(repeats,
                                                    i_b,
                                                    result_eval_glim,
                                                    reused_sample,
                                                    **out_one,
                                                    suffix="3_deGLI")

                    avg_measure.update(measure)

                stime = ms()
                _, output = self.model.plain_gla(x,
                                                 mag,
                                                 max_length,
                                                 repeat=repeats)

                etime = ms(stime)
                speed = (1000 * max_length / hp.fs) * len(T_ys) / etime
                # print("pure gla: %d repeats, length: %d, time: %d milliseconds, ratio = %.02f" % (repeats, max_length, etime, speed))
                # self.writer.add_scalar("Test Performance/gla", speed, repeats)

                # write summary
                for i_b in tqdm(range(len(T_ys)),
                                desc=f"GLA, {repeats} repeats",
                                dynamic_ncols=True):
                    i_sample = cnt_sample + i_b
                    reused_sample, result_eval_glim = sample_dict[i_b]
                    out_one = self.postprocess(output, None, T_ys, i_b,
                                               loader.dataset)
                    measure = self.writer.write_one(repeats,
                                                    i_b,
                                                    result_eval_glim,
                                                    reused_sample,
                                                    **out_one,
                                                    suffix="4_GLA")
                    avg_measure2.update(measure)

                cnt_sample += len(T_ys)

                self.writer.add_scalar('STOI/Average Measure/deGLI',
                                       avg_measure.get_average()[0, 0],
                                       repeats * depth)
                self.writer.add_scalar('STOI/Average Measure/GLA',
                                       avg_measure2.get_average()[0, 0],
                                       repeats * depth)
                # bit_length() is floor(log2(n)) + 1, giving a semilog x-axis
                self.writer.add_scalar('STOI/Average Measure/deGLI_semilogx',
                                       avg_measure.get_average()[0, 0],
                                       int(repeats * depth).bit_length())
                self.writer.add_scalar('STOI/Average Measure/GLA_semilogx',
                                       avg_measure2.get_average()[0, 0],
                                       int(repeats * depth).bit_length())

                self.writer.add_scalar('PESQ/Average Measure/deGLI',
                                       avg_measure.get_average()[0, 1],
                                       repeats * depth)
                self.writer.add_scalar('PESQ/Average Measure/GLA',
                                       avg_measure2.get_average()[0, 1],
                                       repeats * depth)
                self.writer.add_scalar('PESQ/Average Measure/deGLI_semilogx',
                                       avg_measure.get_average()[0, 1],
                                       int(repeats * depth).bit_length())
                self.writer.add_scalar('PESQ/Average Measure/GLA_semilogx',
                                       avg_measure2.get_average()[0, 1],
                                       int(repeats * depth).bit_length())

                repeats *= 2
            break  # only the first batch of the loader is evaluated
        self.model.train()

        self.writer.close()  # Explicitly close
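
ms() is not defined in the listing; from stime = ms() and etime = ms(stime) it evidently returns the current time in milliseconds, or the elapsed milliseconds when given a start time. A sketch under that assumption:

import time

def ms(since=None):
    # current time in ms, or elapsed ms since `since`
    now = time.perf_counter() * 1000.0
    return now if since is None else now - since

With that reading, speed = (max_length / hp.fs) * len(T_ys) / (etime / 1000) appears to be the ratio of processed audio duration to wall-clock time, i.e. a real-time factor.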