Code example #1
    def __init__(self, config, checkpoint_path, output_dir):
        checkpoint_path = Path(checkpoint_path).expanduser().absolute()
        root_dir = Path(output_dir).expanduser().absolute()
        self.device = prepare_device(torch.cuda.device_count())

        print("Loading inference dataset...")
        self.dataloader = self._load_dataloader(config["dataset"])
        print("Loading model...")
        self.model, epoch = self._load_model(config["model"], checkpoint_path, self.device)
        self.inference_config = config["inferencer"]

        self.enhanced_dir = root_dir / f"enhanced_{str(epoch).zfill(4)}"
        prepare_empty_dir([self.enhanced_dir])

        self.acoustic_config = config["acoustic"]
        n_fft = self.acoustic_config["n_fft"]
        hop_length = self.acoustic_config["hop_length"]
        win_length = self.acoustic_config["win_length"]

        self.stft = partial(stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length, device=self.device)
        self.istft = partial(istft, n_fft=n_fft, hop_length=hop_length, win_length=win_length, device=self.device)
        self.librosa_stft = partial(librosa.stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        self.librosa_istft = partial(librosa.istft, hop_length=hop_length, win_length=win_length)

        print("Configurations are as follows: ")
        print(toml.dumps(config))
        with open((root_dir / f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(), "w") as handle:
            toml.dump(config, handle)
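
The constructor above binds the STFT parameters once with functools.partial so that later calls only pass the signal. A minimal self-contained sketch of that pattern, with illustrative parameter values rather than the snippet's config:

from functools import partial

import librosa
import numpy as np

n_fft, hop_length, win_length = 512, 256, 512

my_stft = partial(librosa.stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
my_istft = partial(librosa.istft, hop_length=hop_length, win_length=win_length)

y = np.random.randn(16000).astype(np.float32)  # one second of fake audio at 16 kHz
spec = my_stft(y)       # complex spectrogram, shape (n_fft // 2 + 1, n_frames)
y_hat = my_istft(spec)  # waveform reconstructed from the spectrogram
print(spec.shape, y_hat.shape)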
Code example #2
    def __init__(self, config, resume: bool, model, loss_function, optimizer):
        self.n_gpu = torch.cuda.device_count()
        self.device = prepare_device(
            self.n_gpu, cudnn_deterministic=config["cudnn_deterministic"])

        self.optimizer = optimizer
        self.loss_function = loss_function

        self.model = model.to(self.device)

        if self.n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=list(
                                                   range(self.n_gpu)))

        # Trainer
        self.epochs = config["trainer"]["epochs"]
        self.save_checkpoint_interval = config["trainer"][
            "save_checkpoint_interval"]
        self.validation_config = config["trainer"]["validation"]
        self.train_config = config["trainer"].get("train", {})
        self.validation_interval = self.validation_config["interval"]
        self.find_max = self.validation_config["find_max"]
        self.validation_custom_config = self.validation_config["custom"]
        self.train_custom_config = self.train_config.get("custom", {})

        # The following attributes are not in the config file; they will be updated later if resume is True.
        self.start_epoch = 1
        self.best_score = -np.inf if self.find_max else np.inf
        self.root_dir = Path(config["root_dir"]).expanduser().absolute() / config["experiment_name"]
        self.checkpoints_dir = self.root_dir / "checkpoints"
        self.logs_dir = self.root_dir / "logs"
        prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume=resume)

        self.writer = visualization.writer(self.logs_dir.as_posix())
        self.writer.add_text(
            tag="Configuration",
            text_string=f"<pre>  \n{json5.dumps(config, indent=4, sort_keys=False)}  \n</pre>",
            global_step=1)

        if resume: self._resume_checkpoint()
        if config["preloaded_model_path"]:
            self._preload_model(Path(config["preloaded_model_path"]))

        print("Configurations are as follows: ")
        print(json5.dumps(config, indent=2, sort_keys=False))

        with open((self.root_dir /
                   f"{time.strftime('%Y-%m-%d-%H-%M-%S')}.json").as_posix(),
                  "w") as handle:
            json5.dump(config, handle, indent=2, sort_keys=False)

        self._print_networks([self.model])
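
A minimal sketch of the device-selection and DataParallel wrapping used above; prepare_device here is a hypothetical stand-in for the project's helper (which also handles cudnn_deterministic), and the Linear layer is a placeholder model:

import torch

def prepare_device(n_gpu: int) -> torch.device:
    # Hypothetical stand-in: prefer the first CUDA device when GPUs are available.
    return torch.device("cuda:0" if n_gpu > 0 else "cpu")

n_gpu = torch.cuda.device_count()
device = prepare_device(n_gpu)

model = torch.nn.Linear(257, 257).to(device)
if n_gpu > 1:
    # Replicate the model across all visible GPUs for data-parallel training.
    model = torch.nn.DataParallel(model, device_ids=list(range(n_gpu)))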
Code example #3
    def __init__(self, config, checkpoint_path, output_dir):
        checkpoint_path = Path(checkpoint_path).expanduser().absolute()
        output_root_dir = Path(output_dir).expanduser().absolute()
        self.device = prepare_device(torch.cuda.device_count())

        self.enhanced_dir = output_root_dir / "enhanced"
        prepare_empty_dir([self.enhanced_dir])

        self.dataloader = self._load_dataloader(config["dataset"])
        self.model = self._load_model(config["model"], checkpoint_path,
                                      self.device)
        self.inference_config = config["inference"]

        print("Configurations are as follows: ")
        print(json5.dumps(config, indent=2, sort_keys=False))

        with open((output_root_dir /
                   f"{time.strftime('%Y-%m-%d-%H-%M-%S')}.json").as_posix(),
                  "w") as handle:
            json5.dump(config, handle, indent=2, sort_keys=False)
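
prepare_empty_dir is called throughout these snippets but never shown. A hypothetical minimal version consistent with how it is used (create output directories, and expect them to already exist when resuming); the project's real helper may differ:

from pathlib import Path

def prepare_empty_dir(dirs, resume=False):
    # Hypothetical stand-in for the project's helper.
    for d in dirs:
        d = Path(d)
        if resume:
            assert d.exists(), f"{d} must already exist when resuming"
        else:
            d.mkdir(parents=True, exist_ok=True)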
Code example #4
def main(args):
    sr = args.sr
    metric_types = args.metric_types
    export_dir = args.export_dir
    specific_dataset = args.specific_dataset.lower()

    # Gather all wav samples from the specified scp files or directories
    reference_wav_paths, estimated_wav_paths = pre_processing(
        args.estimated, args.reference, specific_dataset)

    if export_dir:
        export_dir = Path(export_dir).expanduser().absolute()
        prepare_empty_dir([export_dir])

    print(f"=== {args.estimated} === {args.reference} ===")
    for metric_type in metric_types.split(","):
        metrics_result_store = compute_metric(reference_wav_paths,
                                              estimated_wav_paths,
                                              sr,
                                              metric_type=metric_type)

        # Print result
        metric_value = np.mean(list(zip(*metrics_result_store))[1])
        print(f"{metric_type}: {metric_value}")

        # Export result
        if export_dir:
            import tablib

            export_path = export_dir / f"{metric_type}.xlsx"
            print(f"Export result to {export_path}")

            headers = ("Speech", f"{metric_type}")
            metric_seq = [[basename, metric_value]
                          for basename, metric_value in metrics_result_store]
            data = tablib.Dataset(*metric_seq, headers=headers)
            with open(export_path.as_posix(), "wb") as f:
                f.write(data.export("xlsx"))
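
A hedged sketch of an argparse entry point that would supply the fields main() reads (estimated, reference, sr, metric_types, export_dir, specific_dataset); the flag names and defaults are assumptions, not taken from the original script:

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compute metrics between reference and estimated wav files.")
    parser.add_argument("-E", "--estimated", required=True, help="Estimated (enhanced) wav directory or scp file.")
    parser.add_argument("-R", "--reference", required=True, help="Reference (clean) wav directory or scp file.")
    parser.add_argument("--sr", type=int, default=16000, help="Sample rate of the wav files.")
    parser.add_argument("--metric_types", default="STOI", help="Comma-separated metric names.")
    parser.add_argument("--export_dir", default="", help="Optional directory for the xlsx export.")
    parser.add_argument("--specific_dataset", default="", help="Optional dataset-specific path handling.")
    main(parser.parse_args())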
Code example #5
    def __init__(self, config, resume, G, D, optim_G, optim_D,
                 additional_loss_function):
        self.n_gpu = torch.cuda.device_count()
        self.device = self._prepare_device(
            self.n_gpu, cudnn_deterministic=config["cudnn_deterministic"])

        self.generator = G.to(self.device)
        self.discriminator = D.to(self.device)

        if self.n_gpu > 1:
            self.generator = torch.nn.DataParallel(self.generator,
                                                   device_ids=list(
                                                       range(self.n_gpu)))
            self.discriminator = torch.nn.DataParallel(self.discriminator,
                                                       device_ids=list(
                                                           range(self.n_gpu)))

        self.optimizer_G = optim_G
        self.optimizer_D = optim_D

        self.additional_loss_function = additional_loss_function
        self.adversarial_loss_function = torch.nn.BCELoss()

        # Trainer
        self.epochs = config["trainer"]["epochs"]
        self.save_checkpoint_interval = config["trainer"][
            "save_checkpoint_interval"]
        self.soft_label = config["trainer"]["soft_label"]
        self.additional_loss_factor = config["trainer"][
            "additional_loss_factor"]
        self.adversarial_loss_factor = config["trainer"][
            "adversarial_loss_factor"]
        self.validation_config = config["trainer"]["validation"]
        self.validation_interval = self.validation_config["interval"]
        self.find_max = self.validation_config["find_max"]
        self.validation_custom_config = self.validation_config["custom"]

        self.start_epoch = 1
        self.best_score = -np.inf if self.find_max else np.inf
        self.root_dir = Path(config["root_dir"]).expanduser().absolute() / config["experiment_name"]
        self.checkpoints_dir = self.root_dir / "checkpoints"
        self.logs_dir = self.root_dir / "logs"
        prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume=resume)

        self.writer = visualization.writer(self.logs_dir.as_posix())
        self.writer.add_text(
            tag="Configuration",
            text_string=f"<pre>  \n{json5.dumps(config, indent=4, sort_keys=False)}  \n</pre>",
            global_step=1)

        if resume: self._resume_checkpoint()

        print("Configurations are as follows: ")
        print(json5.dumps(config, indent=2, sort_keys=False))

        with open((self.root_dir /
                   f"{time.strftime('%Y-%m-%d-%H-%M-%S')}.json").as_posix(),
                  "w") as handle:
            json5.dump(config, handle, indent=2, sort_keys=False)

        self._print_networks([self.generator, self.discriminator])
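
A minimal sketch of how the two loss factors above are typically combined for a generator update in this kind of GAN trainer; the tensors, the L1 choice for the additional loss, and the factor values are placeholders for illustration:

import torch

adversarial_loss_function = torch.nn.BCELoss()
additional_loss_function = torch.nn.L1Loss()
adversarial_loss_factor = 0.01
additional_loss_factor = 1.0

enhanced = torch.randn(8, 16000, requires_grad=True)  # generator output (placeholder)
clean = torch.randn(8, 16000)                          # clean targets (placeholder)
fake_pred = torch.sigmoid(torch.randn(8, 1))           # discriminator score on enhanced speech (placeholder)
real_label = torch.ones_like(fake_pred)                # with soft_label enabled this would presumably be softened, e.g. 0.9

loss_G = (adversarial_loss_factor * adversarial_loss_function(fake_pred, real_label)
          + additional_loss_factor * additional_loss_function(enhanced, clean))
loss_G.backward()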
Code example #6
File: trainer.py  Project: xzm2004260/FullSubNet
    def __init__(self, dist, rank, config, resume: bool, model, loss_function, optimizer):
        self.color_tool = colorful
        self.color_tool.use_style("solarized")

        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function

        # DistributedDataParallel (DDP)
        self.rank = rank
        self.dist = dist

        # Automatic mixed precision (AMP)
        self.use_amp = config["meta"]["use_amp"]
        self.scaler = GradScaler(enabled=self.use_amp)

        # Acoustics
        self.acoustic_config = config["acoustic"]

        # Supported STFT
        n_fft = self.acoustic_config["n_fft"]
        hop_length = self.acoustic_config["hop_length"]
        win_length = self.acoustic_config["win_length"]
        self.torch_stft = partial(stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length, device=self.rank)
        self.istft = partial(istft, n_fft=n_fft, hop_length=hop_length, win_length=win_length, device=self.rank)
        self.librosa_stft = partial(librosa.stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        self.librosa_istft = partial(librosa.istft, hop_length=hop_length, win_length=win_length)

        # Trainer.train in config
        self.train_config = config["trainer"]["train"]
        self.epochs = self.train_config["epochs"]
        self.save_checkpoint_interval = self.train_config["save_checkpoint_interval"]
        self.clip_grad_norm_value = self.train_config["clip_grad_norm_value"]
        assert self.save_checkpoint_interval >= 1

        # Trainer.validation in config
        self.validation_config = config["trainer"]["validation"]
        self.validation_interval = self.validation_config["validation_interval"]
        self.save_max_metric_score = self.validation_config["save_max_metric_score"]
        assert self.validation_interval >= 1

        # Trainer.visualization in config
        self.visualization_config = config["trainer"]["visualization"]

        # In the 'train.py' file, if the 'resume' item is True, we will update the following args:
        self.start_epoch = 1
        self.best_score = -np.inf if self.save_max_metric_score else np.inf
        self.save_dir = Path(config["meta"]["save_dir"]).expanduser().absolute() / config["meta"]["experiment_name"]
        self.checkpoints_dir = self.save_dir / "checkpoints"
        self.logs_dir = self.save_dir / "logs"

        if resume:
            self._resume_checkpoint()

        if config["meta"]["preloaded_model_path"]:
            self._preload_model(Path(config["meta"]["preloaded_model_path"]))

        if self.rank == 0:
            prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume=resume)

            self.writer = visualization.writer(self.logs_dir.as_posix())
            self.writer.add_text(
                tag="Configuration",
                text_string=f"<pre>  \n{toml.dumps(config)}  \n</pre>",
                global_step=1
            )

            print(self.color_tool.cyan("The configurations are as follows: "))
            print(self.color_tool.cyan("=" * 40))
            print(self.color_tool.cyan(toml.dumps(config)[:-1]))  # drop the trailing "\n"
            print(self.color_tool.cyan("=" * 40))

            with open((self.save_dir / f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(), "w") as handle:
                toml.dump(config, handle)

            self._print_networks([self.model])
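
A minimal sketch of the AMP plus gradient-clipping training step implied by use_amp, self.scaler, and clip_grad_norm_value above; the model, data, and clip value are placeholders, and the sketch falls back to full precision when CUDA is unavailable:

import torch
from torch.cuda.amp import GradScaler, autocast

use_amp = torch.cuda.is_available()
device = torch.device("cuda" if use_amp else "cpu")
clip_grad_norm_value = 10.0

model = torch.nn.Linear(257, 257).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=use_amp)

noisy = torch.randn(4, 257, device=device)
clean = torch.randn(4, 257, device=device)

optimizer.zero_grad()
with autocast(enabled=use_amp):
    enhanced = model(noisy)
    loss = torch.nn.functional.mse_loss(enhanced, clean)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # unscale gradients before clipping their norm
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm_value)
scaler.step(optimizer)
scaler.update()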