def __init__(self, config, resume: bool, model, loss_function, optimizer):
    """Prepare the device, model, trainer settings and experiment folders.

    Args:
        config: Experiment configuration mapping. Keys read here:
            ``cudnn_deterministic``; ``trainer`` -> ``epochs``,
            ``save_checkpoint_interval``, ``validation`` (``interval``,
            ``find_max``, ``custom``) and the optional ``train`` section
            (with optional ``custom``); ``root_dir``; ``experiment_name``;
            ``preloaded_model_path``.
        resume: Forwarded to ``prepare_empty_dir``. When true,
            ``self._resume_checkpoint()`` runs after the log writer exists.
        model: Network to train. It is moved to ``self.device`` and, when
            more than one CUDA device is visible, wrapped in
            ``torch.nn.DataParallel`` over all of them.
        loss_function: Stored unchanged as ``self.loss_function``.
        optimizer: Stored unchanged as ``self.optimizer``.

    Side effects: passes ``<root_dir>/<experiment_name>/checkpoints`` and
    ``.../logs`` to ``prepare_empty_dir``, logs the configuration to the
    writer, prints it, and writes a timestamped ``.json`` copy of it into
    the experiment directory.
    """
    self.n_gpu = torch.cuda.device_count()
    cudnn_deterministic = config["cudnn_deterministic"]
    self.device = prepare_device(self.n_gpu,
                                 cudnn_deterministic=cudnn_deterministic)

    self.optimizer = optimizer
    self.loss_function = loss_function

    self.model = model.to(self.device)
    if self.n_gpu > 1:
        # Replicate the model on every visible GPU.
        all_gpus = list(range(self.n_gpu))
        self.model = torch.nn.DataParallel(self.model, device_ids=all_gpus)

    # Trainer settings.
    trainer_section = config["trainer"]
    self.epochs = trainer_section["epochs"]
    self.save_checkpoint_interval = trainer_section["save_checkpoint_interval"]
    self.validation_config = trainer_section["validation"]
    self.train_config = trainer_section.get("train", {})
    self.validation_interval = self.validation_config["interval"]
    self.find_max = self.validation_config["find_max"]
    self.validation_custom_config = self.validation_config["custom"]
    self.train_custom_config = self.train_config.get("custom", {})

    # Not taken from the config: these are fresh-run values that are meant
    # to be updated later when resuming.
    self.start_epoch = 1
    if self.find_max:
        self.best_score = -np.inf
    else:
        self.best_score = np.inf

    experiments_root = Path(config["root_dir"]).expanduser().absolute()
    self.root_dir = experiments_root / config["experiment_name"]
    self.checkpoints_dir = self.root_dir / "checkpoints"
    self.logs_dir = self.root_dir / "logs"
    prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume=resume)

    self.writer = visualization.writer(self.logs_dir.as_posix())
    config_as_text = json5.dumps(config, indent=4, sort_keys=False)
    self.writer.add_text(tag="Configuration",
                         text_string=f"<pre>  \n{config_as_text}  \n</pre>",
                         global_step=1)

    if resume:
        self._resume_checkpoint()
    preloaded_model_path = config["preloaded_model_path"]
    if preloaded_model_path:
        self._preload_model(Path(preloaded_model_path))

    print("Configurations are as follows: ")
    print(json5.dumps(config, indent=2, sort_keys=False))

    # Timestamped snapshot of the configuration in the experiment folder.
    snapshot_file = self.root_dir / f"{time.strftime('%Y-%m-%d-%H-%M-%S')}.json"
    with open(snapshot_file.as_posix(), "w") as handle:
        json5.dump(config, handle, indent=2, sort_keys=False)

    self._print_networks([self.model])
Пример #2
0
    def __init__(self, config, resume: bool, G, D, optim_G, optim_D, loss_function):
        """Prepare a generator/discriminator trainer: devices, networks, settings and folders.

        Args:
            config: Experiment configuration mapping. Keys read here:
                ``cudnn_deterministic``; ``trainer`` -> ``epochs``,
                ``save_checkpoint_interval`` and ``validation`` (``interval``,
                ``find_max``, ``custom``); ``root_dir``; ``experiment_name``.
            resume: Forwarded to ``prepare_empty_dir``. When true,
                ``self._resume_checkpoint()`` runs after the configuration
                snapshot has been written.
            G: Generator network, stored as ``self.generator``.
            D: Discriminator network, stored as ``self.discriminator``.
            optim_G: Stored as ``self.optimizer_G``.
            optim_D: Stored as ``self.optimizer_D``.
            loss_function: Stored unchanged as ``self.loss_function``.

        Both networks are moved to ``self.device``. With more than one CUDA
        device visible, each is wrapped in ``torch.nn.DataParallel`` over all
        of them.
        """
        self.n_gpus = torch.cuda.device_count()
        self.device = self._prepare_device(
            self.n_gpus, cudnn_deterministic=config["cudnn_deterministic"])

        self.optimizer_G, self.optimizer_D = optim_G, optim_D
        self.loss_function = loss_function

        self.generator = G.to(self.device)
        self.discriminator = D.to(self.device)
        if self.n_gpus > 1:
            visible = range(self.n_gpus)
            self.generator = torch.nn.DataParallel(self.generator, device_ids=list(visible))
            self.discriminator = torch.nn.DataParallel(self.discriminator, device_ids=list(visible))

        # Trainer settings.
        trainer_section = config["trainer"]
        self.epochs = trainer_section["epochs"]
        self.save_checkpoint_interval = trainer_section["save_checkpoint_interval"]
        self.validation_config = trainer_section["validation"]
        validation = self.validation_config
        self.validation_interval = validation["interval"]
        self.find_max = validation["find_max"]
        self.validation_custom_config = validation["custom"]

        # Fresh-run values (not read from the config); meant to be updated
        # when resuming.
        self.start_epoch = 1
        if self.find_max:
            self.best_score = -np.inf
        else:
            self.best_score = np.inf

        experiments_root = Path(config["root_dir"]).expanduser().absolute()
        self.root_dir = experiments_root / config["experiment_name"]
        self.checkpoints_dir = self.root_dir / "checkpoints"
        self.logs_dir = self.root_dir / "logs"
        prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume)

        # Visualization: log the full configuration as preformatted text.
        self.writer = visualization.writer(self.logs_dir.as_posix())
        config_as_text = json5.dumps(config, indent=4, sort_keys=False)
        self.writer.add_text(
            tag="Configuration",
            text_string=f"<pre>  \n{config_as_text}  \n</pre>",
            global_step=1,
        )

        # Timestamped snapshot of the configuration in the experiment folder.
        snapshot_file = self.root_dir / f"{time.strftime('%Y-%m-%d-%H-%M-%S')}.json"
        with open(snapshot_file.as_posix(), "w") as handle:
            json5.dump(config, handle, indent=2, sort_keys=False)

        if resume:
            self._resume_checkpoint()

        print("Configurations are as follows: ")
        print(json5.dumps(config, indent=2, sort_keys=False))

        self._print_networks([self.generator, self.discriminator])
Пример #3
0
    return librosa.load(os.path.abspath(os.path.expanduser(file_path)),
                        sr=16000)[0]


# Sample rate in Hz; load_wav's librosa.load call hard-codes the same 16000.
sr = 16000
# Four seconds' worth of samples (not referenced in the code shown here).
n_samples = 4 * sr
# Reference crop length in seconds (multiplied by sr to get a sample count).
reference_length = 5

# The list file holds one "<mixture> <target> <reference>" triplet per line.
# It is read inside a ``with`` block so the handle is closed as soon as the
# lines are collected, instead of staying open until garbage collection.
_dataset_list_file = os.path.abspath(
    os.path.expanduser(
        "/home/quhongling/experiments/Network-test/dev_dataset_path.txt"))
with open(_dataset_list_file, "r") as _handle:
    dataset_list = [line.rstrip('\n') for line in _handle]

writer = visualization.writer(
    "/home/quhongling/experiments/Network-test/dataset_tmp/logs")

# Load the first 1260 dev triplets, print the raw arrays for inspection, and
# randomly crop references longer than reference_length seconds.
# NOTE(review): range(1260) is hard-coded; dataset_list[item] raises
# IndexError if the list file has fewer lines.
for item in tqdm(range(1260)):
    # Each entry is "<mixture> <target> <reference>". split(" ") expects
    # exactly one space between paths; a path containing a space (or a
    # doubled separator) makes the 3-way unpacking fail with ValueError.
    mixture_path, target_path, reference_path = dataset_list[item].split(" ")

    mixture = load_wav(mixture_path)
    target = load_wav(target_path)
    reference = load_wav(reference_path)

    # Debug dump of the raw sample arrays for this triplet.
    print("\n", item + 1, "it\n", "mixture:", mixture, "\ntarget:", target,
          "\nreference:", reference, "\n")

    # Shorter references are left unchanged in the code visible here.
    if len(reference) > (sr * reference_length):
        # Random window of exactly sr * reference_length samples; the +1 lets
        # randint (exclusive upper bound) draw the last valid start offset.
        start = np.random.randint(len(reference) - sr * reference_length + 1)
        end = start + sr * reference_length
        reference = reference[start:end]
Пример #4
0
    def __init__(self, dist, rank, config, resume: bool, model, loss_function, optimizer):
        """Configure a (possibly distributed) trainer from a nested config.

        Args:
            dist: Distributed handle stored as ``self.dist`` (presumably the
                ``torch.distributed`` module — TODO confirm against caller).
            rank: Process rank. It is passed as ``device`` to the project
                ``stft``/``istft`` helpers. Only rank 0 prepares directories,
                creates ``self.writer``, prints/saves the configuration and
                prints the network.
            config: Nested configuration with ``meta``, ``acoustic`` and
                ``trainer`` (``train``, ``validation``, ``visualization``)
                sections. ``meta.preloaded_model_path`` is optional in value
                (falsy disables preloading) but the key must exist.
            resume: When true, ``self._resume_checkpoint()`` is called. Also
                forwarded to ``prepare_empty_dir`` as ``resume=resume``.
            model: Network to train, stored as-is (no device move or
                wrapping happens here).
            loss_function: Stored as ``self.loss_function``.
            optimizer: Stored as ``self.optimizer``.

        Raises:
            KeyError: If a required config key is missing.
            AssertionError: If ``save_checkpoint_interval`` or
                ``validation_interval`` is below 1 (not checked under
                ``python -O``).
        """
        # ``colorful`` is assigned by reference, so use_style() changes the
        # style for every other user of that same object.
        self.color_tool = colorful
        self.color_tool.use_style("solarized")

        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function

        # DistributedDataParallel (DDP)
        self.rank = rank
        self.dist = dist

        # Automatic mixed precision (AMP); with enabled=False the scaler's
        # methods become pass-throughs / no-ops.
        self.use_amp = config["meta"]["use_amp"]
        self.scaler = GradScaler(enabled=self.use_amp)

        # Acoustics
        self.acoustic_config = config["acoustic"]

        # STFT / iSTFT callables with the frame settings pre-bound; the
        # project stft/istft helpers additionally receive device=self.rank.
        n_fft = self.acoustic_config["n_fft"]
        hop_length = self.acoustic_config["hop_length"]
        win_length = self.acoustic_config["win_length"]
        self.torch_stft = partial(stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length, device=self.rank)
        self.istft = partial(istft, n_fft=n_fft, hop_length=hop_length, win_length=win_length, device=self.rank)
        self.librosa_stft = partial(librosa.stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        self.librosa_istft = partial(librosa.istft, hop_length=hop_length, win_length=win_length)

        # Trainer.train in config
        self.train_config = config["trainer"]["train"]
        self.epochs = self.train_config["epochs"]
        self.save_checkpoint_interval = self.train_config["save_checkpoint_interval"]
        self.clip_grad_norm_value = self.train_config["clip_grad_norm_value"]
        assert self.save_checkpoint_interval >= 1, "save_checkpoint_interval must be >= 1"

        # Trainer.validation in config
        self.validation_config = config["trainer"]["validation"]
        self.validation_interval = self.validation_config["validation_interval"]
        self.save_max_metric_score = self.validation_config["save_max_metric_score"]
        assert self.validation_interval >= 1, "validation_interval must be >= 1"

        # Trainer.visualization in config
        self.visualization_config = config["trainer"]["visualization"]

        # Fresh-run defaults; when 'resume' is true these are expected to be
        # updated by _resume_checkpoint() below.
        self.start_epoch = 1
        # Worst possible starting score for the chosen optimisation direction.
        self.best_score = -np.inf if self.save_max_metric_score else np.inf
        self.save_dir = Path(config["meta"]["save_dir"]).expanduser().absolute() / config["meta"]["experiment_name"]
        self.checkpoints_dir = self.save_dir / "checkpoints"
        self.logs_dir = self.save_dir / "logs"

        if resume:
            self._resume_checkpoint()

        # Read the path from the same "meta" entry that is checked; loading
        # from the top-level key raised KeyError for configs that only define
        # it under "meta".
        preloaded_model_path = config["meta"]["preloaded_model_path"]
        if preloaded_model_path:
            self._preload_model(Path(preloaded_model_path))

        # Filesystem, logging and printing happen on rank 0 only, so other
        # ranks never get a ``self.writer`` attribute.
        if self.rank == 0:
            prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume=resume)

            self.writer = visualization.writer(self.logs_dir.as_posix())
            self.writer.add_text(
                tag="Configuration",
                text_string=f"<pre>  \n{toml.dumps(config)}  \n</pre>",
                global_step=1
            )

            print(self.color_tool.cyan("The configurations are as follows: "))
            print(self.color_tool.cyan("=" * 40))
            print(self.color_tool.cyan(toml.dumps(config)[:-1]))  # drop the trailing "\n"
            print(self.color_tool.cyan("=" * 40))

            # Filename-safe timestamp: ':' is a reserved character in Windows
            # file names, and spaces are awkward in shells, so only '-' is used.
            timestamp = time.strftime('%Y-%m-%d-%H-%M-%S')
            with open((self.save_dir / f"{timestamp}.toml").as_posix(), "w") as handle:
                toml.dump(config, handle)

            self._print_networks([self.model])
Пример #5
0
    return librosa.load(os.path.abspath(os.path.expanduser(file_path)),
                        sr=16000)[0]


# Sample rate in Hz; load_wav's librosa.load call hard-codes the same 16000.
sr = 16000
# 32000 samples, i.e. two seconds at sr (not referenced in the code shown here).
n_samples = 32000
# Reference clip length in seconds (multiplied by sr to get a sample count).
reference_length = 5

# One "<mixture> <target> <reference>" triplet per line. The file is read
# inside a ``with`` block so the handle is closed right after the lines are
# collected, instead of staying open until garbage collection.
_dataset_list_file = os.path.abspath(
    os.path.expanduser(
        "/home/quhongling/dataset/mix_2/dev/dev_dataset_path.txt"))
with open(_dataset_list_file, "r") as _handle:
    dataset_list = [line.rstrip('\n') for line in _handle]

writer = visualization.writer(
    "/home/quhongling/experiments/SpEx/dataset_tmp/logs")

# Walk the first 20 dev triplets and force every reference clip to exactly
# sr * reference_length samples: random crop if longer, zero-pad otherwise.
# NOTE(review): dataset_list[item] raises IndexError if the list file has
# fewer than 20 lines.
for item in range(20):
    # Each entry is "<mixture> <target> <reference>". split(" ") expects
    # exactly one space between paths, otherwise the unpacking fails.
    mixture_path, target_path, reference_path = dataset_list[item].split(" ")

    # mixture and target are not used in the code visible here.
    mixture = load_wav(mixture_path)
    target = load_wav(target_path)
    reference = load_wav(reference_path)

    if len(reference) > (sr * reference_length):
        # Random window of sr * reference_length samples; the +1 lets randint
        # (exclusive upper bound) draw the last valid start offset.
        start = np.random.randint(len(reference) - sr * reference_length + 1)
        end = start + sr * reference_length
        reference = reference[start:end]
    else:
        # Right-pad with zeros (np.pad's default constant mode) up to the
        # fixed length; a clip of exactly that length gets zero padding.
        reference = np.pad(reference,
                           (0, sr * reference_length - len(reference)))