import os
import pickle

import numpy as np
import scipy.io as sio
import torch as th

# parse_yaml, parse_scps, stft, istft, uttloader, PITNet, PITrainer,
# Separator and logger are project-local helpers defined elsewhere in
# this repo.


def run(args):
    num_bins, config_dict = parse_yaml(args.config)
    dataloader_conf = config_dict["dataloader"]
    spectrogram_conf = config_dict["spectrogram_reader"]
    # Load cmvn statistics if a pickled dict is configured
    dict_mvn = dataloader_conf["mvn_dict"]
    if dict_mvn:
        if not os.path.exists(dict_mvn):
            raise FileNotFoundError(
                "Could not find mvn file {}".format(dict_mvn))
        with open(dict_mvn, "rb") as f:
            dict_mvn = pickle.load(f)
    # Whether to compress spectrogram magnitudes with log (default: True)
    apply_log = dataloader_conf.get("apply_log", True)

    dcnet = PITNet(num_bins, **config_dict["model"])

    frame_length = spectrogram_conf["frame_length"]
    frame_shift = spectrogram_conf["frame_shift"]
    window = spectrogram_conf["window"]

    separator = Separator(dcnet, args.state_dict, cuda=args.cuda)

    utt_dict = parse_scps(args.wave_scp)
    num_utts = 0
    for key, utt in utt_dict.items():
        try:
            samps, stft_mat = stft(
                utt,
                frame_length=frame_length,
                frame_shift=frame_shift,
                window=window,
                center=True,
                return_samps=True)
        except FileNotFoundError:
            print("Skipping utterance {}: wave file not found".format(key))
            continue
        print("Processing utterance {}".format(key))
        num_utts += 1
        # Remember the peak amplitude so the separated wavs can be
        # rescaled to the input's level
        norm = np.linalg.norm(samps, np.inf)
        spk_mask, spk_spectrogram = separator.seperate(
            stft_mat, cmvn=dict_mvn, apply_log=apply_log)
        # Reconstruct one wav (and optionally one mask) per speaker
        for index, spk_mat in enumerate(spk_spectrogram):
            istft(
                os.path.join(args.dump_dir,
                             '{}.spk{}.wav'.format(key, index + 1)),
                spk_mat,
                frame_length=frame_length,
                frame_shift=frame_shift,
                window=window,
                center=True,
                norm=norm,
                fs=8000,
                nsamps=samps.size)
            if args.dump_mask:
                sio.savemat(
                    os.path.join(args.dump_dir,
                                 '{}.spk{}.mat'.format(key, index + 1)),
                    {"mask": spk_mask[index]})
    print("Processed {} utterances".format(num_utts))
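
# A minimal sketch of the CLI wiring run() appears to assume. Every flag
# name below mirrors an args.* attribute read above (config, state_dict,
# wave_scp, cuda, dump_dir, dump_mask); the parser itself, its defaults and
# help strings are illustrative assumptions, not the repo's actual entry
# point.
def _build_run_parser():
    import argparse
    parser = argparse.ArgumentParser(
        description="Separate speakers from mixed utterances "
                    "with a trained PITNet")
    parser.add_argument("config", help="YAML config used at training time")
    parser.add_argument("state_dict", help="checkpoint to load into PITNet")
    parser.add_argument("wave_scp", help="scp file listing mixed wav inputs")
    parser.add_argument("--cuda", action="store_true",
                        help="run the separator on GPU")
    parser.add_argument("--dump-dir", dest="dump_dir", default="cache",
                        help="directory to write separated wav/mask files")
    parser.add_argument("--dump-mask", dest="dump_mask", action="store_true",
                        help="also dump per-speaker masks as .mat files")
    return parser

# e.g. run(_build_run_parser().parse_args())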
def train(args):
    # GPUs are given as a comma-separated list, e.g. "0,1"
    gpuid = tuple(map(int, args.gpus.split(',')))
    debug = args.debug
    logger.info(
        "Start training in {} mode".format('debug' if debug else 'normal'))
    num_bins, config_dict = parse_yaml(args.config)
    reader_conf = config_dict["spectrogram_reader"]
    loader_conf = config_dict["dataloader"]
    dcnnet_conf = config_dict["model"]

    # apply_abs selects magnitude targets (IRM) over phase-sensitive (PSM)
    logger.info("Training with {}".format(
        "IRM" if reader_conf["apply_abs"] else "PSM"))
    batch_size = loader_conf["batch_size"]
    logger.info("Training {}".format(
        "per utterance" if batch_size == 1 else
        "with {} utterances per batch".format(batch_size)))

    # In debug mode both loaders read the (small) debug scp set
    train_loader = uttloader(
        config_dict["train_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=True)
    valid_loader = uttloader(
        config_dict["valid_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=False)

    checkpoint = config_dict["trainer"]["checkpoint"]
    logger.info("Training for {} epochs -> {}...".format(
        args.num_epoches,
        "default checkpoint" if checkpoint is None else checkpoint))

    nnet = PITNet(num_bins, **dcnnet_conf)
    trainer = PITrainer(nnet, **config_dict["trainer"], gpuid=gpuid)
    trainer.run(train_loader, valid_loader, num_epoches=args.num_epoches)
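
# A sketch of the config structure parse_yaml() is expected to return for
# the functions in this file, reconstructed only from the keys they read.
# Concrete values below are illustrative assumptions, not the repo's
# defaults.
example_config = {
    "spectrogram_reader": {
        "frame_length": 512,   # STFT window size in samples
        "frame_shift": 256,    # hop size in samples
        "window": "hann",
        "apply_abs": True,     # True -> IRM targets, False -> PSM targets
    },
    "dataloader": {
        "batch_size": 1,       # 1 -> per-utterance training
        "mvn_dict": "data/cmvn.dict",  # pickled cmvn stats; may be empty
        "apply_log": True,     # optional; run() defaults it to True
    },
    "model": {},               # forwarded as PITNet(num_bins, **model)
    "trainer": {"checkpoint": None},  # forwarded to PITrainer(...)
    "train_scp_conf": {},      # scp sets consumed by uttloader(...)
    "valid_scp_conf": {},
    "debug_scp_conf": {},      # used for both loaders when debug is set
}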
# Variant of train() that supports warm-starting from a saved state dict
# and resuming at a given epoch.
def train(args):
    debug = args.debug
    logger.info(
        "Start training in {} mode".format('debug' if debug else 'normal'))
    num_bins, config_dict = parse_yaml(args.config)
    reader_conf = config_dict["spectrogram_reader"]
    loader_conf = config_dict["dataloader"]
    dcnnet_conf = config_dict["model"]

    state_dict = args.state_dict
    # Map checkpoint tensors onto CPU when requested, otherwise keep the
    # devices recorded in the checkpoint
    location = "cpu" if args.cpu else None

    logger.info("Training with {}".format(
        "IRM" if reader_conf["apply_abs"] else "PSM"))
    batch_size = loader_conf["batch_size"]
    logger.info("Training {}".format(
        "per utterance" if batch_size == 1 else
        "with {} utterances per batch".format(batch_size)))

    train_loader = uttloader(
        config_dict["train_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=True)
    valid_loader = uttloader(
        config_dict["valid_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=False)

    checkpoint = config_dict["trainer"]["checkpoint"]
    logger.info("Training for {} epochs -> {}...".format(
        args.num_epoches,
        "default checkpoint" if checkpoint is None else checkpoint))

    nnet = PITNet(num_bins, **dcnnet_conf)
    # Warm-start from an existing state dict when one is given
    if state_dict != "":
        if not os.path.exists(state_dict):
            raise FileNotFoundError(
                "Could not find state dict {}".format(state_dict))
        logger.info("Loading state dict from {}".format(state_dict))
        nnet.load_state_dict(th.load(state_dict, map_location=location))

    trainer = PITrainer(nnet, **config_dict["trainer"])
    trainer.run(
        train_loader,
        valid_loader,
        num_epoches=args.num_epoches,
        start=args.start)
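
# A minimal CLI sketch for the resume-capable train() above; flag names
# mirror the args.* attributes it reads (config, state_dict, cpu, start,
# debug, num_epoches), while the defaults and help strings are assumptions,
# not the repo's actual script.
def _build_train_parser():
    import argparse
    parser = argparse.ArgumentParser(
        description="Train (or resume training) a PITNet model")
    parser.add_argument("config", help="path to the training YAML config")
    parser.add_argument("--state-dict", dest="state_dict", default="",
                        help="checkpoint to warm-start from "
                             "('' trains from scratch)")
    parser.add_argument("--cpu", action="store_true",
                        help="map the loaded state dict onto CPU")
    parser.add_argument("--start", type=int, default=0,
                        help="epoch index to resume from")
    parser.add_argument("--debug", action="store_true",
                        help="train on the debug scp set")
    parser.add_argument("--num-epoches", dest="num_epoches", type=int,
                        default=20, help="number of training epochs")
    return parser

# e.g. train(_build_train_parser().parse_args())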