Example #1
0
def main(args):
    """Evaluate a trained Conv-TasNet checkpoint on the test set.

    Builds the test dataset/loader, restores the model from
    ``args.model_path``, optionally moves it to CUDA, and runs the
    PIT-wrapped criterion over the whole test set via ``Tester``.
    """
    set_seed(args.seed)

    dataset = WaveTestDataset(args.wav_root, args.test_json_path)
    print("Test dataset includes {} samples.".format(len(dataset)))

    # Evaluation is done one mixture at a time, in file order.
    test_loader = TestDataLoader(dataset, batch_size=1, shuffle=False)

    # Restore the trained network from its checkpoint.
    model = ConvTasNet.build_model(args.model_path)
    print(model)
    print("# Parameters: {}".format(model.num_parameters))

    # Device placement: requesting CUDA without a visible GPU is an error.
    if not args.use_cuda:
        print("Does NOT use CUDA")
    elif torch.cuda.is_available():
        model.cuda()
        model = nn.DataParallel(model)
        print("Use CUDA")
    else:
        raise ValueError("Cannot use CUDA.")

    # Criterion (only negative SI-SDR is supported here).
    if args.criterion != 'sisdr':
        raise ValueError("Not support criterion {}".format(args.criterion))
    criterion = NegSISDR()

    # Permutation-invariant wrapper over the per-source criterion.
    pit_criterion = PIT1d(criterion, n_sources=args.n_sources)

    tester = Tester(model, test_loader, pit_criterion, args)
    tester.run()
def main(args):
    """Train a Conv-TasNet model from scratch.

    Builds train/valid datasets and loaders, instantiates the network
    from command-line hyper-parameters, selects optimizer and criterion,
    and hands everything to ``Trainer``.
    """
    set_seed(args.seed)

    train_dataset = WaveTrainDataset(args.wav_root, args.train_json_path)
    valid_dataset = WaveTrainDataset(args.wav_root, args.valid_json_path)
    print("Training dataset includes {} samples.".format(len(train_dataset)))
    print("Valid dataset includes {} samples.".format(len(valid_dataset)))

    # Trainer expects a dict with 'train' and 'valid' loaders.
    loader = {
        'train': TrainDataLoader(train_dataset, batch_size=args.batch_size, shuffle=True),
        'valid': TrainDataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False),
    }

    model = ConvTasNet(
        args.n_basis, args.kernel_size, stride=args.stride,
        enc_basis=args.enc_basis, dec_basis=args.dec_basis,
        enc_nonlinear=args.enc_nonlinear, window_fn=args.window_fn,
        sep_hidden_channels=args.sep_hidden_channels,
        sep_bottleneck_channels=args.sep_bottleneck_channels,
        sep_skip_channels=args.sep_skip_channels,
        sep_kernel_size=args.sep_kernel_size,
        sep_num_blocks=args.sep_num_blocks,
        sep_num_layers=args.sep_num_layers,
        dilated=args.dilated, separable=args.separable, causal=args.causal,
        sep_nonlinear=args.sep_nonlinear, sep_norm=args.sep_norm,
        mask_nonlinear=args.mask_nonlinear, n_sources=args.n_sources)
    print(model)
    print("# Parameters: {}".format(model.num_parameters))

    # Device placement: requesting CUDA without a visible GPU is an error.
    if not args.use_cuda:
        print("Does NOT use CUDA")
    elif torch.cuda.is_available():
        model.cuda()
        model = nn.DataParallel(model)
        print("Use CUDA")
    else:
        raise ValueError("Cannot use CUDA.")

    # Optimizer: dispatch table over the supported choices.
    optimizer_classes = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
        'rmsprop': torch.optim.RMSprop,
    }
    if args.optimizer not in optimizer_classes:
        raise ValueError("Not support optimizer {}".format(args.optimizer))
    optimizer = optimizer_classes[args.optimizer](
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Criterion (only negative SI-SDR is supported here).
    if args.criterion != 'sisdr':
        raise ValueError("Not support criterion {}".format(args.criterion))
    criterion = NegSISDR()

    # Permutation-invariant wrapper over the per-source criterion.
    pit_criterion = PIT1d(criterion, n_sources=args.n_sources)

    trainer = Trainer(model, loader, pit_criterion, optimizer, args)
    trainer.run()
Example #3
0
def main(args):
    """Train Conv-TasNet on mixtures with a variable number of sources.

    Derives segment sizes from the sample rate and durations, builds the
    mixed-number-of-sources datasets/loaders, instantiates the network,
    and runs training with the one-and-rest PIT (ORPIT) criterion via
    ``AdhocTrainer``.
    """
    set_seed(args.seed)

    # Segment lengths in samples; training segments overlap by half.
    samples = int(args.sr * args.duration)
    overlap = samples // 2
    max_samples = int(args.sr * args.valid_duration)

    # args.n_sources is a '+'-separated list, e.g. "2+3"; keep the largest.
    max_n_sources = max(int(token) for token in args.n_sources.split('+'))

    train_dataset = MixedNumberSourcesWaveTrainDataset(
        args.train_wav_root,
        args.train_list_path,
        samples=samples,
        overlap=overlap,
        max_n_sources=max_n_sources)
    valid_dataset = MixedNumberSourcesWaveEvalDataset(
        args.valid_wav_root,
        args.valid_list_path,
        max_samples=max_samples,
        max_n_sources=max_n_sources)
    print("Training dataset includes {} samples.".format(len(train_dataset)))
    print("Valid dataset includes {} samples.".format(len(valid_dataset)))

    # Validation runs one mixture at a time; training is batched.
    loader = {
        'train': MixedNumberSourcesTrainDataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True),
        'valid': MixedNumberSourcesEvalDataLoader(
            valid_dataset, batch_size=1, shuffle=False),
    }

    # Normalize a falsy CLI value (e.g. empty string) to None.
    args.enc_nonlinear = args.enc_nonlinear or None
    # n_sources is fixed at 2: ORPIT separates one source vs. the rest.
    model = ConvTasNet(
        args.n_bases, args.kernel_size, stride=args.stride,
        enc_bases=args.enc_bases, dec_bases=args.dec_bases,
        enc_nonlinear=args.enc_nonlinear, window_fn=args.window_fn,
        sep_hidden_channels=args.sep_hidden_channels,
        sep_bottleneck_channels=args.sep_bottleneck_channels,
        sep_skip_channels=args.sep_skip_channels,
        sep_kernel_size=args.sep_kernel_size,
        sep_num_blocks=args.sep_num_blocks,
        sep_num_layers=args.sep_num_layers,
        dilated=args.dilated, separable=args.separable, causal=args.causal,
        sep_nonlinear=args.sep_nonlinear, sep_norm=args.sep_norm,
        mask_nonlinear=args.mask_nonlinear, n_sources=2)
    print(model)
    print("# Parameters: {}".format(model.num_parameters))

    # Device placement: requesting CUDA without a visible GPU is an error.
    if not args.use_cuda:
        print("Does NOT use CUDA")
    elif torch.cuda.is_available():
        model.cuda()
        model = nn.DataParallel(model)
        print("Use CUDA")
    else:
        raise ValueError("Cannot use CUDA.")

    # Optimizer: dispatch table over the supported choices.
    optimizer_classes = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
        'rmsprop': torch.optim.RMSprop,
    }
    if args.optimizer not in optimizer_classes:
        raise ValueError("Not support optimizer {}".format(args.optimizer))
    optimizer = optimizer_classes[args.optimizer](
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Criterion (only negative SI-SDR is supported here).
    if args.criterion != 'sisdr':
        raise ValueError("Not support criterion {}".format(args.criterion))
    criterion = NegSISDR()

    # One-and-rest permutation-invariant training wrapper.
    pit_criterion = ORPIT(criterion)

    trainer = AdhocTrainer(model, loader, pit_criterion, optimizer, args)
    trainer.run()