def main(args):
    """Evaluate a ConvTasNet checkpoint on the test set.

    Seeds the RNGs, loads the test dataset, restores the model from
    ``args.model_path``, optionally moves it to the GPU, wraps the chosen
    criterion with ``PIT1d`` and runs ``Tester``.

    Args:
        args: Parsed command-line options. Reads ``seed``, ``wav_root``,
            ``test_json_path``, ``model_path``, ``use_cuda``, ``criterion``
            and ``n_sources``; the whole namespace is also forwarded to
            ``Tester``.

    Raises:
        ValueError: If ``args.use_cuda`` is set but CUDA is unavailable, or
            if ``args.criterion`` is anything other than ``'sisdr'``.
    """
    set_seed(args.seed)

    # ---- data --------------------------------------------------------------
    dataset = WaveTestDataset(args.wav_root, args.test_json_path)
    print("Test dataset includes {} samples.".format(len(dataset)))
    loader = TestDataLoader(dataset, batch_size=1, shuffle=False)

    # ---- model -------------------------------------------------------------
    model = ConvTasNet.build_model(args.model_path)
    print(model)
    # Printed here, before the optional DataParallel wrap below rebinds `model`.
    print("# Parameters: {}".format(model.num_parameters))

    if not args.use_cuda:
        print("Does NOT use CUDA")
    elif torch.cuda.is_available():
        model.cuda()
        model = nn.DataParallel(model)
        print("Use CUDA")
    else:
        raise ValueError("Cannot use CUDA.")

    # ---- loss --------------------------------------------------------------
    if args.criterion != 'sisdr':
        raise ValueError("Not support criterion {}".format(args.criterion))
    pit_criterion = PIT1d(NegSISDR(), n_sources=args.n_sources)

    Tester(model, loader, pit_criterion, args).run()
def main(args):
    """Train a ConvTasNet model from the options in ``args``.

    Builds train/valid datasets and loaders and constructs a ``ConvTasNet``
    from the command-line options. It then optionally moves the model to the
    GPU, creates the optimizer and a ``PIT1d``-wrapped criterion, and runs
    ``Trainer``.

    Args:
        args: Parsed command-line options. Reads ``seed``, ``wav_root``,
            ``train_json_path``, ``valid_json_path``, ``batch_size``,
            ``n_basis``, ``kernel_size``, the model options listed below,
            ``use_cuda``, ``optimizer``, ``lr``, ``weight_decay``,
            ``criterion`` and ``n_sources``; the whole namespace is also
            forwarded to ``Trainer``.

    Raises:
        ValueError: If ``args.use_cuda`` is set but CUDA is unavailable, if
            ``args.optimizer`` is not one of ``'sgd'``, ``'adam'``,
            ``'rmsprop'``, or if ``args.criterion`` is not ``'sisdr'``.
    """
    set_seed(args.seed)

    # ---- data --------------------------------------------------------------
    train_dataset = WaveTrainDataset(args.wav_root, args.train_json_path)
    valid_dataset = WaveTrainDataset(args.wav_root, args.valid_json_path)
    print("Training dataset includes {} samples.".format(len(train_dataset)))
    print("Valid dataset includes {} samples.".format(len(valid_dataset)))

    loader = {
        'train': TrainDataLoader(train_dataset, batch_size=args.batch_size, shuffle=True),
        'valid': TrainDataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False),
    }

    # ---- model -------------------------------------------------------------
    # Each keyword option is taken from the attribute of the same name on args.
    model_option_names = (
        'stride', 'enc_basis', 'dec_basis', 'enc_nonlinear', 'window_fn',
        'sep_hidden_channels', 'sep_bottleneck_channels', 'sep_skip_channels',
        'sep_kernel_size', 'sep_num_blocks', 'sep_num_layers',
        'dilated', 'separable', 'causal',
        'sep_nonlinear', 'sep_norm', 'mask_nonlinear', 'n_sources',
    )
    model_options = {name: getattr(args, name) for name in model_option_names}
    model = ConvTasNet(args.n_basis, args.kernel_size, **model_options)
    print(model)
    # Printed here, before the optional DataParallel wrap below rebinds `model`.
    print("# Parameters: {}".format(model.num_parameters))

    if not args.use_cuda:
        print("Does NOT use CUDA")
    elif torch.cuda.is_available():
        model.cuda()
        model = nn.DataParallel(model)
        print("Use CUDA")
    else:
        raise ValueError("Cannot use CUDA.")

    # ---- optimizer ---------------------------------------------------------
    optimizer_classes = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
        'rmsprop': torch.optim.RMSprop,
    }
    optimizer_class = optimizer_classes.get(args.optimizer)
    if optimizer_class is None:
        raise ValueError("Not support optimizer {}".format(args.optimizer))
    optimizer = optimizer_class(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # ---- loss --------------------------------------------------------------
    if args.criterion != 'sisdr':
        raise ValueError("Not support criterion {}".format(args.criterion))
    pit_criterion = PIT1d(NegSISDR(), n_sources=args.n_sources)

    Trainer(model, loader, pit_criterion, optimizer, args).run()
def main(args):
    """Train a DPRNNTasNet model with ``AdhocTrainer``.

    Builds a segmented training dataset and a length-capped validation
    dataset, together with their loaders. It normalises two options on
    ``args`` in place, constructs a ``DPRNNTasNet`` and optionally moves it
    to the GPU. Finally it creates the optimizer and a ``PIT1d``-wrapped
    criterion and runs ``AdhocTrainer``.

    Args:
        args: Parsed command-line options. Reads ``seed``, ``sr``,
            ``duration``, ``valid_duration``, ``train_wav_root``,
            ``train_list_path``, ``valid_wav_root``, ``valid_list_path``,
            ``n_sources``, ``batch_size``, ``enc_nonlinear``, ``max_norm``,
            ``n_bases``, ``kernel_size``, the model options listed below,
            ``use_cuda``, ``optimizer``, ``lr``, ``weight_decay`` and
            ``criterion``; the whole namespace is also forwarded to
            ``AdhocTrainer``.

    Side effects:
        Mutates ``args``: a falsy ``enc_nonlinear`` becomes ``None`` and a
        ``max_norm`` equal to 0 becomes ``None``.

    Raises:
        ValueError: If ``args.use_cuda`` is set but CUDA is unavailable, if
            ``args.optimizer`` is not one of ``'sgd'``, ``'adam'``,
            ``'rmsprop'``, or if ``args.criterion`` is not ``'sisdr'``.
    """
    set_seed(args.seed)

    # ---- data --------------------------------------------------------------
    # NOTE(review): presumably sr is a sample rate and the durations are in
    # seconds, making these lengths in samples — confirm against the CLI.
    samples = int(args.sr * args.duration)
    max_samples = int(args.sr * args.valid_duration)

    train_dataset = WaveTrainDataset(
        args.train_wav_root, args.train_list_path,
        samples=samples, overlap=samples // 2, n_sources=args.n_sources,
    )
    valid_dataset = WaveEvalDataset(
        args.valid_wav_root, args.valid_list_path,
        max_samples=max_samples, n_sources=args.n_sources,
    )
    print("Training dataset includes {} samples.".format(len(train_dataset)))
    print("Valid dataset includes {} samples.".format(len(valid_dataset)))

    loader = {
        'train': TrainDataLoader(train_dataset, batch_size=args.batch_size, shuffle=True),
        'valid': EvalDataLoader(valid_dataset, batch_size=1, shuffle=False),
    }

    # ---- option normalisation (mutates args in place) ----------------------
    # Presumably None means "disabled" for the model / trainer — confirm.
    if not args.enc_nonlinear:
        args.enc_nonlinear = None
    if args.max_norm == 0:  # None == 0 is False, so a None max_norm is left as-is
        args.max_norm = None

    # ---- model -------------------------------------------------------------
    # Each keyword option is taken from the attribute of the same name on args.
    model_option_names = (
        'stride', 'enc_bases', 'dec_bases', 'enc_nonlinear', 'window_fn',
        'sep_hidden_channels', 'sep_bottleneck_channels',
        'sep_chunk_size', 'sep_hop_size', 'sep_num_blocks',
        'causal', 'sep_norm', 'mask_nonlinear', 'n_sources',
    )
    model_options = {name: getattr(args, name) for name in model_option_names}
    model = DPRNNTasNet(args.n_bases, args.kernel_size, **model_options)
    print(model)
    # Printed here, before the optional DataParallel wrap below rebinds `model`.
    print("# Parameters: {}".format(model.num_parameters))

    if not args.use_cuda:
        print("Does NOT use CUDA")
    elif torch.cuda.is_available():
        model.cuda()
        model = nn.DataParallel(model)
        print("Use CUDA")
    else:
        raise ValueError("Cannot use CUDA.")

    # ---- optimizer ---------------------------------------------------------
    optimizer_classes = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
        'rmsprop': torch.optim.RMSprop,
    }
    optimizer_class = optimizer_classes.get(args.optimizer)
    if optimizer_class is None:
        raise ValueError("Not support optimizer {}".format(args.optimizer))
    optimizer = optimizer_class(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # ---- loss --------------------------------------------------------------
    if args.criterion != 'sisdr':
        raise ValueError("Not support criterion {}".format(args.criterion))
    pit_criterion = PIT1d(NegSISDR(), n_sources=args.n_sources)

    AdhocTrainer(model, loader, pit_criterion, optimizer, args).run()