Ejemplo n.º 1
0
def main(config, resume):
    """Build the GAN test pipeline from *config* and run it.

    Args:
        config: parsed configuration dict (seed, datasets, models, optimizer).
        resume: checkpoint/resume specification forwarded to the Tester.
    """
    seed = config["seed"]
    torch.manual_seed(seed)  # seeds both CPU and CUDA RNGs
    np.random.seed(seed)

    test_dl = DataLoader(
        dataset=initialize_config(config["test_dataset"]),
        batch_size=1,
        num_workers=1,
    )

    g_model = initialize_config(config["generator_model"])
    d_model = initialize_config(config["discriminator_model"])

    opt_conf = config["optimizer"]
    betas = (opt_conf["beta1"], opt_conf["beta2"])
    g_optim = torch.optim.Adam(params=g_model.parameters(),
                               lr=opt_conf["G_lr"],
                               betas=betas)
    d_optim = torch.optim.Adam(params=d_model.parameters(),
                               lr=opt_conf["D_lr"],
                               betas=betas)

    extra_loss = initialize_config(config["additional_loss_function"])

    tester = Tester(config=config,
                    resume=resume,
                    generator=g_model,
                    discriminator=d_model,
                    generator_optimizer=g_optim,
                    discriminator_optimizer=d_optim,
                    additional_loss_function=extra_loss,
                    test_dl=test_dl)
    tester.test()
Ejemplo n.º 2
0
def main(config, checkpoint_path, output_dir):
    """Instantiate the configured inferencer class and run inference.

    Args:
        config: parsed configuration dict; `config["inference"]` names the class.
        checkpoint_path: path to the model checkpoint handed to the inferencer.
        output_dir: directory where the inferencer writes its results.
    """
    # pass_args=False returns the class itself; we construct it positionally.
    inferencer_cls = initialize_config(config["inference"], pass_args=False)
    inferencer_cls(config, checkpoint_path, output_dir).inference()
Ejemplo n.º 3
0
def main(config, resume):
    """Seed RNGs, build the GAN training pipeline from *config*, and train.

    Args:
        config: parsed configuration dict (seed, datasets, models, optimizer).
        resume: checkpoint/resume specification forwarded to the Trainer.
    """
    seed = config["seed"]
    torch.manual_seed(seed)  # seeds both CPU and CUDA RNGs
    np.random.seed(seed)

    dl_conf = config["train_dataloader"]
    train_dl = DataLoader(
        dataset=initialize_config(config["train_dataset"]),
        batch_size=dl_conf["batch_size"],
        num_workers=dl_conf["num_workers"],
        shuffle=dl_conf["shuffle"],
        pin_memory=dl_conf["pin_memory"],  # set False for a very small dataset
    )
    valid_dl = DataLoader(
        dataset=initialize_config(config["validation_dataset"]),
        batch_size=1,
        num_workers=1,
    )

    g_model = initialize_config(config["generator_model"])
    d_model = initialize_config(config["discriminator_model"])

    opt_conf = config["optimizer"]
    betas = (opt_conf["beta1"], opt_conf["beta2"])
    g_optim = torch.optim.Adam(params=g_model.parameters(),
                               lr=opt_conf["G_lr"],
                               betas=betas)
    d_optim = torch.optim.Adam(params=d_model.parameters(),
                               lr=opt_conf["D_lr"],
                               betas=betas)

    extra_loss = initialize_config(config["additional_loss_function"])

    trainer = Trainer(
        config=config,
        resume=resume,
        generator=g_model,
        discriminator=d_model,
        generator_optimizer=g_optim,
        discriminator_optimizer=d_optim,
        additional_loss_function=extra_loss,
        train_dl=train_dl,
        validation_dl=valid_dl,
    )
    trainer.train()
 def _load_dataloader(dataset_config):
     """Instantiate the dataset described by *dataset_config* and wrap it
     in a single-sample, single-process DataLoader."""
     return DataLoader(
         dataset=initialize_config(dataset_config),
         batch_size=1,
         num_workers=0,
     )
Ejemplo n.º 5
0
def main(config, resume):
    """Seed RNGs, assemble the training pipeline from *config*, and train.

    The trainer class itself is resolved from `config["trainer"]`.

    Args:
        config: parsed configuration dict (seed, datasets, model, optimizer).
        resume: checkpoint/resume specification forwarded to the trainer.
    """
    seed = config["seed"]
    torch.manual_seed(seed)  # seeds both CPU and GPU RNGs
    np.random.seed(seed)

    dl_conf = config["train_dataloader"]
    train_dl = DataLoader(
        dataset=initialize_config(config["train_dataset"]),
        batch_size=dl_conf["batch_size"],
        num_workers=dl_conf["num_workers"],
        shuffle=dl_conf["shuffle"],
        pin_memory=dl_conf["pin_memory"],
    )
    valid_dl = DataLoader(
        dataset=initialize_config(config["validation_dataset"]),
        num_workers=1,
        batch_size=1,
    )

    model = initialize_config(config["model"])

    opt_conf = config["optimizer"]
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=opt_conf["lr"],
        betas=(opt_conf["beta1"], opt_conf["beta2"]),
    )

    loss_function = initialize_config(config["loss_function"])

    # pass_args=False returns the trainer class, not an instance.
    trainer_cls = initialize_config(config["trainer"], pass_args=False)
    trainer = trainer_cls(
        config=config,
        resume=resume,
        model=model,
        loss_function=loss_function,
        optimizer=optimizer,
        train_dataloader=train_dl,
        validation_dataloader=valid_dl,
    )
    trainer.train()
Ejemplo n.º 6
0
def main(config, resume):
    """Seed RNGs, build the padded-batch training pipeline, and train.

    Training batches are padded to the longest item in each batch via
    `pad_to_longest_in_one_batch`.

    Args:
        config: parsed configuration dict (seed, datasets, model, optimizer).
        resume: checkpoint/resume specification forwarded to the Trainer.
    """
    seed = config["seed"]
    torch.manual_seed(seed)  # seeds both CPU and GPU RNGs
    np.random.seed(seed)

    dl_conf = config["train_dataloader"]
    train_dl = DataLoader(
        dataset=initialize_config(config["train_dataset"]),
        batch_size=dl_conf["batch_size"],
        num_workers=dl_conf["num_workers"],
        shuffle=dl_conf["shuffle"],
        # Set pin_memory False for a very small dataset, otherwise True.
        pin_memory=dl_conf["pin_memory"],
        collate_fn=pad_to_longest_in_one_batch,
    )
    valid_dl = DataLoader(
        dataset=initialize_config(config["validation_dataset"]),
        batch_size=1,
        num_workers=1,
    )

    model = initialize_config(config["model"])

    opt_conf = config["optimizer"]
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=opt_conf["lr"],
        betas=(opt_conf["beta1"], opt_conf["beta2"]),
    )

    loss_function = initialize_config(config["loss_function"])

    trainer = Trainer(
        config=config,
        resume=resume,
        model=model,
        optimizer=optimizer,
        loss_function=loss_function,
        train_dl=train_dl,
        validation_dl=valid_dl,
    )
    trainer.train()
Ejemplo n.º 7
0
def main(config, resume):
    """Seed RNGs, build model/optimizer/dataloaders from `config`, and
    launch training via GeneralTrainer.

    Args:
        config: parsed configuration dict (seed, datasets, model, optimizer).
        resume: checkpoint/resume specification forwarded to GeneralTrainer.
    """

    torch.manual_seed(int(config["seed"]))  # both CPU and CUDA
    np.random.seed(config["seed"])

    train_dataloader = DataLoader(
        dataset=initialize_config(config["train_dataset"]),
        batch_size=config["train_dataloader"]["batch_size"],
        num_workers=config["train_dataloader"]["num_workers"],
        shuffle=config["train_dataloader"]["shuffle"],
        pin_memory=config["train_dataloader"][
            "pin_memory"]  # Very small data set False
    )

    # Validation always runs one sample at a time.
    validation_dataloader = DataLoader(dataset=initialize_config(
        config["validation_dataset"]),
                                       batch_size=1,
                                       num_workers=1)

    model = initialize_config(config["model"])

    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=config["optimizer"]["lr"],
                                 betas=(config["optimizer"]["beta1"],
                                        config["optimizer"]["beta2"]))

    loss_function = initialize_config(config["loss_function"])

    # NOTE(review): the keyword names `optim` and `loss_fucntion` (sic) must
    # match GeneralTrainer's signature exactly — the misspelling is presumably
    # present on the trainer side as well; confirm before "fixing" it here.
    trainer = GeneralTrainer(config=config,
                             resume=resume,
                             model=model,
                             optim=optimizer,
                             loss_fucntion=loss_function,
                             train_dl=train_dataloader,
                             validation_dl=validation_dataloader)

    trainer.train()
    def _load_model(model_config, checkpoint_path, device):
        """Build the model from *model_config*, load weights from
        *checkpoint_path*, and return it in eval mode on *device*.

        A ``.tar`` checkpoint is treated as a full training checkpoint dict
        (state dict under the "model" key); any other extension is loaded as
        a bare state dict.
        """
        model = initialize_config(model_config)

        extension = os.path.splitext(os.path.basename(checkpoint_path))[-1]
        if extension == ".tar":
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model_static_dict = checkpoint["model"]
            print(
                f"Loading model checkpoint with *.tar format, the epoch is: {checkpoint['epoch']}."
            )
        else:
            model_static_dict = torch.load(checkpoint_path,
                                           map_location=device)

        model.load_state_dict(model_static_dict)
        model.to(device)
        model.eval()
        return model
args = parser.parse_args()

"""
Preparation
"""
# Restrict visible GPUs before any CUDA context is created.
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
config = json.load(open(args.config))
model_checkpoint_path = args.model_checkpoint_path
output_dir = args.output_dir
assert os.path.exists(output_dir), "Enhanced directory should be exist."

"""
DataLoader
"""
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Enhancement is done one sample at a time in the main process.
dataloader = DataLoader(dataset=initialize_config(config["dataset"]), batch_size=1, num_workers=0)

"""
Model
"""
model = initialize_config(config["model"])
model.load_state_dict(load_checkpoint(model_checkpoint_path, device))
model.to(device)
model.eval()

"""
Enhancement
"""
# Segment length used when chunking mixtures for enhancement.
sample_length = config["custom"]["sample_length"]
# NOTE(review): the loop body is truncated in this excerpt — only the
# batch-size guard is visible here.
for mixture, name in tqdm(dataloader):
    assert len(name) == 1, "Only support batch size is 1 in enhancement stage."
Ejemplo n.º 10
0
                    help="Checkpoint.")
args = parser.parse_args()
"""
Preparation
"""
# Restrict visible GPUs before any CUDA context is created.
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
config = json.load(open(args.config))
model_checkpoint_path = args.model_checkpoint_path
output_dir = args.output_dir
assert os.path.exists(output_dir), "Enhanced directory should be exist."
"""
DataLoader
"""
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device(
    "cpu")
# Enhancement is done one sample at a time in the main process.
dataloader = DataLoader(dataset=initialize_config(config["dataset"]),
                        batch_size=1,
                        num_workers=0)
"""
Model
"""
model = initialize_config(config["model"])
model.load_state_dict(load_checkpoint(model_checkpoint_path, device))
model.to(device)
model.eval()
"""
Enhancement
"""
# Here the segment length comes from the dataset itself, not the config.
sample_length = dataloader.dataset.sample_length
# NOTE(review): the loop body is truncated in this excerpt — only the
# batch-size guard is visible here.
for mixture, name in tqdm(dataloader):
    assert len(name) == 1, "Only support batch size is 1 in enhancement stage."
Ejemplo n.º 11
0
def main(config, resume):
    """Seed RNGs, build a padded-batch training pipeline, and train.

    Training batches are padded to the longest sequence in each batch by a
    local collate function; the trainer class is resolved from
    `config["trainer"]`.

    Args:
        config: parsed configuration dict (seed, datasets, model, optimizer).
        resume: checkpoint/resume specification forwarded to the trainer.
    """
    seed = config["seed"]
    torch.manual_seed(seed)  # seeds both CPU and GPU RNGs
    np.random.seed(seed)

    def collate_fn_pad(batch):
        """Pad a batch of (noisy, clean, n_frames, name) items to the
        longest sequence in the batch.

        Returns:
            [B, F, T (Longest)]
        """
        noisy_seqs = [torch.tensor(noisy).permute(1, 0)  # [F, T] => [T, F]
                      for noisy, _, _, _ in batch]
        clean_seqs = [torch.tensor(clean).permute(1, 0)  # [1, T] => [T, 1]
                      for _, clean, _, _ in batch]
        n_frames_list = [n_frames for _, _, n_frames, _ in batch]
        names = [name for _, _, _, name in batch]

        # pad_sequence takes items of shape (T, *) and stacks them into
        # (longest_T, B, *); permute back to batch-first layout.
        noisy_batch = pad_sequence(noisy_seqs).permute(1, 2, 0)  # => [B, F, T]
        clean_batch = pad_sequence(clean_seqs).permute(1, 2, 0)  # => [B, 1, T]

        return noisy_batch, clean_batch, n_frames_list, names

    dl_conf = config["train_dataloader"]
    train_dl = DataLoader(
        dataset=initialize_config(config["train_dataset"]),
        batch_size=dl_conf["batch_size"],
        num_workers=dl_conf["num_workers"],
        shuffle=dl_conf["shuffle"],
        pin_memory=dl_conf["pin_memory"],
        collate_fn=collate_fn_pad,
    )
    valid_dl = DataLoader(
        dataset=initialize_config(config["validation_dataset"]),
        num_workers=0,
        batch_size=1,
    )

    model = initialize_config(config["model"])

    opt_conf = config["optimizer"]
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=opt_conf["lr"],
        betas=(opt_conf["beta1"], opt_conf["beta2"]),
    )

    loss_function = initialize_config(config["loss_function"])
    # pass_args=False returns the trainer class, not an instance.
    trainer_cls = initialize_config(config["trainer"], pass_args=False)

    trainer = trainer_cls(
        config=config,
        resume=resume,
        model=model,
        loss_function=loss_function,
        optimizer=optimizer,
        train_dataloader=train_dl,
        validation_dataloader=valid_dl,
    )
    trainer.train()