Exemplo n.º 1
0
def train(args):
    """Build datasets and loaders from a YAML config and train a Chimera net.

    Args:
        args: parsed CLI namespace providing ``config`` (path to the YAML
            config), ``alpha`` (chimera loss mixing weight forwarded to the
            trainer) and ``num_epoches``.
    """
    num_bins, config_dict = parse_yaml(args.config)
    loader_conf = config_dict["dataloader"]
    dcnet_conf = config_dict["dcnet"]
    train_config = config_dict["trainer"]

    train_dataset = SpectrogramDataset(loader_conf["train_path_npz"])
    valid_dataset = SpectrogramDataset(loader_conf["valid_path_npz"])

    train_loader = DataLoader(train_dataset,
                              batch_size=loader_conf["batch_size"],
                              shuffle=True,
                              num_workers=4,
                              drop_last=True,
                              pin_memory=True)
    # Validation: no shuffling and no dropped tail batch, so metrics are
    # deterministic and computed over the *entire* validation set
    # (the original shuffled and silently dropped the last partial batch).
    valid_loader = DataLoader(valid_dataset,
                              batch_size=loader_conf["batch_size"],
                              shuffle=False,
                              num_workers=4,
                              drop_last=False,
                              pin_memory=True)

    chimera = chimeraNet(num_bins, **dcnet_conf)
    trainer = PerUttTrainer(chimera, args.alpha, **train_config)
    trainer.run(train_loader, valid_loader, num_epoches=args.num_epoches)
Exemplo n.º 2
0
def get_loaders_for_training(args_dataset: tp.Dict, args_loader: tp.Dict,
                             train_file_list: tp.List[str],
                             val_file_list: tp.List[str]):
    """Construct training and validation DataLoaders over SpectrogramDataset.

    Args:
        args_dataset: keyword arguments forwarded to ``SpectrogramDataset``.
        args_loader: dict with ``"train"`` and ``"val"`` sub-dicts of
            ``DataLoader`` keyword arguments (batch size, workers, ...).
        train_file_list: files backing the training dataset.
        val_file_list: files backing the validation dataset.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    # make datasets
    train_dataset = SpectrogramDataset(train_file_list, **args_dataset)
    val_dataset = SpectrogramDataset(val_file_list, **args_dataset)

    # make dataloaders
    train_loader = data.DataLoader(train_dataset, **args_loader["train"])
    val_loader = data.DataLoader(val_dataset, **args_loader["val"])

    return train_loader, val_loader
Exemplo n.º 3
0
def train():
    """Train a U-Net source separator on precomputed spectrograms.

    NOTE(review): everything from ``torch.manual_seed(args.seed)`` onward
    appears to belong to a *different* example (a seq2seq ASR setup) that
    was fused into this function by the scrape -- the preceding
    ``trainer.run(...)`` call already completes the U-Net flow, and ``args``
    is never defined in this scope. Confirm against the original sources.
    """
    # Dataset paths and hyperparameters come from the project config module C.
    train_dataset = SpectrogramDataset(C.PATH_FFT)
    valid_dataset = SpectrogramDataset(C.VAL_PATH_FFT)

    train_loader = DataLoader(train_dataset,
                              batch_size=C.BATCH_SIZE,
                              shuffle=True,
                              num_workers=4,
                              drop_last=True,
                              pin_memory=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=C.BATCH_SIZE,
                              shuffle=True,
                              num_workers=4,
                              drop_last=True,
                              pin_memory=True)

    unet = U_Net()
    trainer = Trainer(unet, C.CHECK_POINT, C.LR)
    trainer.run(train_loader, valid_loader, num_epoches=C.num_epoches)

    # --- fragment below presumably from a separate example (see NOTE above) ---
    # Seed every RNG source for reproducibility; `args` is undefined here.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # Special token ids for the decoder vocabulary.
    SOS_token = 70
    EOS_token = 71
    PAD_token = 0

    device = torch.device('cuda' if args.cuda else 'cpu')

    batch_size = args.batch_size

    # Bucketing sampler batches utterances of similar length together.
    train_dataset = SpectrogramDataset(dataset_path=args.dataset_path,
                                       data_list=args.rootpath +
                                       "train_list.csv")
    train_sampler = BucketingSampler(train_dataset, batch_size=batch_size)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=4,
                                   batch_sampler=train_sampler)

    test_dataset = SpectrogramDataset(dataset_path=args.dataset_path,
                                      data_list=args.rootpath +
                                      "valid_list.csv")
    # batch_size=1: evaluate one utterance at a time.
    test_loader = AudioDataLoader(test_dataset, num_workers=4, batch_size=1)

    # 80-dimensional input features -- presumably filterbank/spectrogram bins;
    # TODO confirm against SpectrogramDataset.
    input_size = 80
    # NOTE(review): this call is truncated by the scrape; the remaining
    # arguments (and the rest of the example) are missing from this chunk.
    enc = EncoderRNN(input_size,
                     args.encoder_size,
                     n_layers=args.encoder_layers,
Exemplo n.º 5
0

    # NOTE(review): fragment -- the enclosing function's `def` line (and the
    # definitions of `device`, `data_root`, `hp`, `build_model`, etc.) are
    # outside this chunk, so their semantics cannot be confirmed here.
    # build model, create optimizer
    model = build_model().to(device)
    optimizer = AdamW(model.parameters(),
                           lr=hp.initial_learning_rate, betas=(
        hp.adam_beta1, hp.adam_beta2),
        eps=hp.adam_eps, weight_decay=hp.weight_decay,
        amsgrad=hp.amsgrad)


    # create dataloaders
    # spec_info.pkl holds the dataset index consumed by SpectrogramDataset.
    with open(os.path.join(data_root, 'spec_info.pkl'), 'rb') as f:
        spec_info = pickle.load(f)
    train_specs = spec_info
    trainset = SpectrogramDataset(data_root, train_specs)
    trainloader = DataLoader(trainset, collate_fn=basic_collate, shuffle=True, num_workers=6, batch_size=32)


    # main train loop
    try:
        losses = train_loop(device, model, trainloader, optimizer)
    except KeyboardInterrupt:
        print("Interrupted!")
        pass
    finally:
        # NOTE(review): if train_loop raises (e.g. the KeyboardInterrupt
        # caught above) before returning, `losses` is never bound and the
        # zip() below raises NameError inside this finally block -- fix by
        # initializing `losses = []` before the try.
        plt.figure()
        x, y = zip(*losses)
        plt.plot(x, y)
        plt.xlabel('Learning Rate')
        plt.ylabel('Loss')
Exemplo n.º 6
0
        
    # NOTE(review): fragment -- this `else:` pairs with an `if` outside the
    # chunk (presumably "start fresh vs. resume"); `model`, `optimizer`,
    # `hp`, `data_root`, etc. are defined elsewhere.
    else:
        # Resume training from an existing checkpoint.
        model = load_checkpoint(checkpoint_path, model, optimizer, True)
        print("loading model from checkpoint:{}".format(checkpoint_path))
        # set global_test_step to True so we don't evaluate right when we load in the model
        global_test_step = True

    # create dataloaders
    with open(os.path.join(data_root, 'spec_info.pkl'), 'rb') as f:
        spec_info = pickle.load(f)
    test_path = os.path.join(data_root, "test")
    with open(os.path.join(test_path, "test_spec_info.pkl"), 'rb') as f:
        test_spec_info = pickle.load(f)
    test_specs = test_spec_info
    train_specs = spec_info
    trainset = SpectrogramDataset(data_root, train_specs)
    testset = SpectrogramDataset(test_path, test_specs)
    # Shuffle before truncating so the validation subset is a random sample.
    random.shuffle(testset.metadata)
    if hp.validation_size is not None:
        testset.metadata = testset.metadata[:hp.validation_size]
    print(f"Training examples: {len(trainset)}")
    print(f"Validation examples: {len(testset)}")
    trainloader = DataLoader(trainset, collate_fn=basic_collate, shuffle=True, num_workers=2, batch_size=hp.batch_size)
    # NOTE(review): shuffle=True on the validation loader is unusual --
    # harmless for aggregate metrics but makes per-batch logs
    # non-reproducible; confirm intent.
    testloader = DataLoader(testset, collate_fn=basic_collate, shuffle=True, num_workers=2, batch_size=hp.test_batch_size)


    # main train loop
    # NOTE(review): the except handler's body is cut off by the scrape.
    try:
        train_loop(device, model, trainloader, testloader, optimizer, checkpoint_dir, eval_dir)
    except KeyboardInterrupt:
        print("Interrupted!")