Exemplo n.º 1
0
def main(args):
    """Run the evaluation pipeline for a saved ConvTasNet model.

    Seeds the run, builds the test dataset and its loader, restores the
    model through ``ConvTasNet.build_model`` and hands everything to
    ``Tester``, whose ``run()`` method is then invoked.

    Args:
        args: Namespace-like object. Attributes read here: ``seed``,
            ``wav_root``, ``test_json_path``, ``model_path``, ``use_cuda``,
            ``criterion`` and ``n_sources``. The whole object is also
            forwarded to ``Tester``.

    Raises:
        ValueError: If ``args.use_cuda`` is truthy but CUDA is not
            available, or if ``args.criterion`` is not ``'sisdr'``.
    """
    set_seed(args.seed)

    eval_set = WaveTestDataset(args.wav_root, args.test_json_path)
    print("Test dataset includes {} samples.".format(len(eval_set)))

    # Batches of a single sample, shuffling disabled.
    eval_loader = TestDataLoader(eval_set, batch_size=1, shuffle=False)

    net = ConvTasNet.build_model(args.model_path)
    print(net)
    # Read the parameter count before the model is possibly wrapped in
    # nn.DataParallel below.
    print("# Parameters: {}".format(net.num_parameters))

    # Device placement
    if not args.use_cuda:
        print("Does NOT use CUDA")
    elif not torch.cuda.is_available():
        raise ValueError("Cannot use CUDA.")
    else:
        net.cuda()
        net = nn.DataParallel(net)
        print("Use CUDA")

    # Criterion: 'sisdr' is the only supported choice.
    if args.criterion != 'sisdr':
        raise ValueError("Not support criterion {}".format(args.criterion))
    base_criterion = NegSISDR()

    wrapped_criterion = PIT1d(base_criterion, n_sources=args.n_sources)

    Tester(net, eval_loader, wrapped_criterion, args).run()
Exemplo n.º 2
0
def main(args):
    """Run the training pipeline for a freshly constructed ConvTasNet.

    Seeds the run, builds the training/validation datasets and loaders,
    constructs the model from the hyper-parameters in ``args``, selects the
    optimizer and criterion, and hands everything to ``Trainer``, whose
    ``run()`` method is then invoked.

    Args:
        args: Namespace-like object. Attributes read here: ``seed``,
            ``wav_root``, ``train_json_path``, ``valid_json_path``,
            ``batch_size``, the model hyper-parameters passed to
            ``ConvTasNet`` below, ``use_cuda``, ``optimizer``, ``lr``,
            ``weight_decay``, ``criterion`` and ``n_sources``. The whole
            object is also forwarded to ``Trainer``.

    Raises:
        ValueError: If ``args.use_cuda`` is truthy but CUDA is not
            available, if ``args.optimizer`` is not one of ``'sgd'``,
            ``'adam'``, ``'rmsprop'``, or if ``args.criterion`` is not
            ``'sisdr'``.
    """
    set_seed(args.seed)

    train_set = WaveTrainDataset(args.wav_root, args.train_json_path)
    valid_set = WaveTrainDataset(args.wav_root, args.valid_json_path)
    print("Training dataset includes {} samples.".format(len(train_set)))
    print("Valid dataset includes {} samples.".format(len(valid_set)))

    # Only the training loader shuffles; both use the same batch size.
    loaders = {
        'train': TrainDataLoader(train_set, batch_size=args.batch_size, shuffle=True),
        'valid': TrainDataLoader(valid_set, batch_size=args.batch_size, shuffle=False),
    }

    net = ConvTasNet(
        args.n_basis,
        args.kernel_size,
        stride=args.stride,
        enc_basis=args.enc_basis,
        dec_basis=args.dec_basis,
        enc_nonlinear=args.enc_nonlinear,
        window_fn=args.window_fn,
        sep_hidden_channels=args.sep_hidden_channels,
        sep_bottleneck_channels=args.sep_bottleneck_channels,
        sep_skip_channels=args.sep_skip_channels,
        sep_kernel_size=args.sep_kernel_size,
        sep_num_blocks=args.sep_num_blocks,
        sep_num_layers=args.sep_num_layers,
        dilated=args.dilated,
        separable=args.separable,
        causal=args.causal,
        sep_nonlinear=args.sep_nonlinear,
        sep_norm=args.sep_norm,
        mask_nonlinear=args.mask_nonlinear,
        n_sources=args.n_sources,
    )
    print(net)
    # Read the parameter count before the model is possibly wrapped in
    # nn.DataParallel below.
    print("# Parameters: {}".format(net.num_parameters))

    # Device placement
    if not args.use_cuda:
        print("Does NOT use CUDA")
    elif not torch.cuda.is_available():
        raise ValueError("Cannot use CUDA.")
    else:
        net.cuda()
        net = nn.DataParallel(net)
        print("Use CUDA")

    # Optimizer: every supported choice takes the same lr / weight_decay.
    optimizer_classes = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
        'rmsprop': torch.optim.RMSprop,
    }
    if args.optimizer not in optimizer_classes:
        raise ValueError("Not support optimizer {}".format(args.optimizer))
    optimizer = optimizer_classes[args.optimizer](
        net.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    # Criterion: 'sisdr' is the only supported choice.
    if args.criterion != 'sisdr':
        raise ValueError("Not support criterion {}".format(args.criterion))
    base_criterion = NegSISDR()

    wrapped_criterion = PIT1d(base_criterion, n_sources=args.n_sources)

    Trainer(net, loaders, wrapped_criterion, optimizer, args).run()
Exemplo n.º 3
0
def main(args):
    """Run the training pipeline for a freshly constructed DPRNNTasNet.

    Seeds the run, builds the training dataset (``WaveTrainDataset``) and
    validation dataset (``WaveEvalDataset``) with their loaders, constructs
    the model from the hyper-parameters in ``args``, selects the optimizer
    and criterion, and hands everything to ``AdhocTrainer``, whose ``run()``
    method is then invoked.

    Note that ``args`` is modified in place: a falsy ``args.enc_nonlinear``
    is replaced by ``None`` and an ``args.max_norm`` equal to 0 is replaced
    by ``None``.

    Args:
        args: Namespace-like object. Attributes read here: ``seed``, ``sr``,
            ``duration``, ``valid_duration``, ``train_wav_root``,
            ``train_list_path``, ``valid_wav_root``, ``valid_list_path``,
            ``n_sources``, ``batch_size``, ``max_norm``, the model
            hyper-parameters passed to ``DPRNNTasNet`` below, ``use_cuda``,
            ``optimizer``, ``lr``, ``weight_decay`` and ``criterion``. The
            whole object is also forwarded to ``AdhocTrainer``.

    Raises:
        ValueError: If ``args.use_cuda`` is truthy but CUDA is not
            available, if ``args.optimizer`` is not one of ``'sgd'``,
            ``'adam'``, ``'rmsprop'``, or if ``args.criterion`` is not
            ``'sisdr'``.
    """
    set_seed(args.seed)

    # Lengths in samples: sampling rate times a duration taken from args.
    train_length = int(args.sr * args.duration)
    train_overlap = train_length // 2  # half of the training length
    valid_max_length = int(args.sr * args.valid_duration)

    train_set = WaveTrainDataset(
        args.train_wav_root,
        args.train_list_path,
        samples=train_length,
        overlap=train_overlap,
        n_sources=args.n_sources,
    )
    valid_set = WaveEvalDataset(
        args.valid_wav_root,
        args.valid_list_path,
        max_samples=valid_max_length,
        n_sources=args.n_sources,
    )
    print("Training dataset includes {} samples.".format(len(train_set)))
    print("Valid dataset includes {} samples.".format(len(valid_set)))

    # Training batches are shuffled; validation uses batch_size=1, unshuffled.
    loaders = {
        'train': TrainDataLoader(train_set, batch_size=args.batch_size, shuffle=True),
        'valid': EvalDataLoader(valid_set, batch_size=1, shuffle=False),
    }

    # Normalise "disabled" settings to None on the args object itself, since
    # args is forwarded to the trainer further down.
    if not args.enc_nonlinear:
        args.enc_nonlinear = None
    if args.max_norm == 0:  # None == 0 is False, so None is left as is
        args.max_norm = None

    n_bases = args.n_bases
    kernel_size = args.kernel_size
    model_options = {
        'stride': args.stride,
        'enc_bases': args.enc_bases,
        'dec_bases': args.dec_bases,
        'enc_nonlinear': args.enc_nonlinear,
        'window_fn': args.window_fn,
        'sep_hidden_channels': args.sep_hidden_channels,
        'sep_bottleneck_channels': args.sep_bottleneck_channels,
        'sep_chunk_size': args.sep_chunk_size,
        'sep_hop_size': args.sep_hop_size,
        'sep_num_blocks': args.sep_num_blocks,
        'causal': args.causal,
        'sep_norm': args.sep_norm,
        'mask_nonlinear': args.mask_nonlinear,
        'n_sources': args.n_sources,
    }
    net = DPRNNTasNet(n_bases, kernel_size, **model_options)
    print(net)
    # Read the parameter count before the model is possibly wrapped in
    # nn.DataParallel below.
    print("# Parameters: {}".format(net.num_parameters))

    # Device placement
    if not args.use_cuda:
        print("Does NOT use CUDA")
    elif not torch.cuda.is_available():
        raise ValueError("Cannot use CUDA.")
    else:
        net.cuda()
        net = nn.DataParallel(net)
        print("Use CUDA")

    # Optimizer: every supported choice takes the same lr / weight_decay.
    optimizer_classes = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
        'rmsprop': torch.optim.RMSprop,
    }
    if args.optimizer not in optimizer_classes:
        raise ValueError("Not support optimizer {}".format(args.optimizer))
    optimizer = optimizer_classes[args.optimizer](
        net.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    # Criterion: 'sisdr' is the only supported choice.
    if args.criterion != 'sisdr':
        raise ValueError("Not support criterion {}".format(args.criterion))
    base_criterion = NegSISDR()

    wrapped_criterion = PIT1d(base_criterion, n_sources=args.n_sources)

    AdhocTrainer(net, loaders, wrapped_criterion, optimizer, args).run()