def config_file(tmp_path: Path):
    """Write the default configuration via a dry run and return its path.

    Args:
        tmp_path: Directory handed to ``EnhancementTask.main`` as
            ``--output_dir``.

    Returns:
        ``tmp_path / "config.yaml"``.
    """
    # Write default configuration file
    dry_run_cmd = ["--dry_run", "true", "--output_dir", str(tmp_path)]
    EnhancementTask.main(cmd=dry_run_cmd)
    return tmp_path / "config.yaml"
def main(cmd=None):
    r"""Enhancement frontend training.

    Thin entry point that forwards ``cmd`` unchanged to
    ``EnhancementTask.main``.

    Args:
        cmd: Value passed through as ``EnhancementTask.main(cmd=cmd)``;
            defaults to ``None``.

    Example:

        % python enh_train.py asr --print_config --optim adadelta \
                > conf/train_enh.yaml
        % python enh_train.py --config conf/train_enh.yaml
    """
    EnhancementTask.main(cmd=cmd)
def __init__(
    self,
    enh_train_config: Union[Path, str],
    enh_model_file: Optional[Union[Path, str]] = None,
    segment_size: Optional[float] = None,
    hop_size: Optional[float] = None,
    normalize_segment_scale: bool = False,
    show_progressbar: bool = False,
    ref_channel: Optional[int] = None,
    normalize_output_wav: bool = False,
    device: str = "cpu",
    dtype: str = "float32",
):
    """Build the enhancement model and store the inference options.

    Args:
        enh_train_config: Training configuration passed to
            ``EnhancementTask.build_model_from_file``.
        enh_model_file: Model parameter file passed to
            ``EnhancementTask.build_model_from_file``; may be ``None``.
        segment_size: Segment length in seconds (per the log message).
            Segment-wise processing is enabled only when both
            ``segment_size`` and ``hop_size`` are given.
        hop_size: Hop length in seconds (per the log message).
        normalize_segment_scale: Stored on the instance as-is.
        show_progressbar: Stored on the instance as-is.
        ref_channel: If given, overwrites
            ``enh_model.separator.ref_channel``; otherwise the value of
            ``enh_model.ref_channel`` is kept.
        normalize_output_wav: Stored on the instance as-is.
        device: Device passed to ``build_model_from_file``.
        dtype: Name of a ``torch`` dtype attribute (e.g. ``"float32"``)
            that the model is cast to.
    """
    assert check_argument_types()

    # 1. Build Enh model
    enh_model, enh_train_args = EnhancementTask.build_model_from_file(
        enh_train_config, enh_model_file, device
    )
    enh_model.to(dtype=getattr(torch, dtype)).eval()

    self.device = device
    self.dtype = dtype
    self.enh_train_args = enh_train_args
    self.enh_model = enh_model

    # only used when processing long speech, i.e.
    # segment_size is not None and hop_size is not None
    self.segment_size = segment_size
    self.hop_size = hop_size
    self.normalize_segment_scale = normalize_segment_scale
    self.normalize_output_wav = normalize_output_wav
    self.show_progressbar = show_progressbar

    self.num_spk = enh_model.num_spk
    task = "enhancement" if self.num_spk == 1 else "separation"

    # reference channel for processing multi-channel speech
    # NOTE(review): the override is written to enh_model.separator.ref_channel
    # while the fallback is read from enh_model.ref_channel -- confirm both
    # attributes are meant to stay in sync.
    if ref_channel is not None:
        logging.info(
            "Overwrite enh_model.separator.ref_channel with %s", ref_channel
        )
        enh_model.separator.ref_channel = ref_channel
        self.ref_channel = ref_channel
    else:
        self.ref_channel = enh_model.ref_channel

    self.segmenting = segment_size is not None and hop_size is not None
    if self.segmenting:
        logging.info("Perform segment-wise speech %s", task)
        logging.info(
            "Segment length = %s sec, hop length = %s sec", segment_size, hop_size
        )
    else:
        if segment_size is not None or hop_size is not None:
            # Only one of the two was supplied: make the silent fallback visible.
            logging.warning(
                "segment_size and hop_size must both be set to enable "
                "segment-wise processing (got segment_size=%s, hop_size=%s)",
                segment_size,
                hop_size,
            )
        logging.info("Perform direct speech %s on the input", task)
def config_file(tmp_path: Path):
    """Write a default configuration under ``tmp_path / "enh"`` and return it.

    After the dry run, the YAML file is reloaded; when the encoder is
    ``"stft"`` with an empty ``encoder_conf``, the default keyword arguments
    of ``STFTEncoder`` are filled in and the file is written back.

    Args:
        tmp_path: Base directory; the config lives in its ``enh`` subfolder.

    Returns:
        Path of the (possibly rewritten) ``config.yaml``.
    """
    enh_dir = tmp_path / "enh"
    config_path = enh_dir / "config.yaml"

    # Write default configuration file
    EnhancementTask.main(
        cmd=["--dry_run", "true", "--output_dir", str(enh_dir)]
    )

    with open(config_path, "r") as stream:
        config = yaml.safe_load(stream)

    uses_stft = config["encoder"] == "stft"
    if uses_stft and len(config["encoder_conf"]) == 0:
        config["encoder_conf"] = get_default_kwargs(STFTEncoder)

    with open(config_path, "w") as stream:
        yaml_no_alias_safe_dump(config, stream, indent=4, sort_keys=False)

    return config_path
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    fs: int,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    model_tag: Optional[str],
    inference_config: Optional[str],
    allow_variable_data_keys: bool,
    segment_size: Optional[float],
    hop_size: Optional[float],
    normalize_segment_scale: bool,
    show_progressbar: bool,
    ref_channel: Optional[int],
    normalize_output_wav: bool,
    enh_s2t_task: bool,
):
    """Run speech enhancement/separation over a dataset and write the waves.

    A ``SeparateSpeech`` instance is built via ``from_pretrained``, every
    batch of the streaming iterator is passed through it, and each returned
    wave is written with one ``SoundScpWriter`` per speaker
    (``<output_dir>/wavs/<n>`` and ``<output_dir>/spk<n>.scp``).

    Args:
        output_dir: Output directory; expanded and resolved before use.
        batch_size: Must not exceed 1.
        dtype: Passed to ``SeparateSpeech`` and to the data iterator.
        fs: Value stored together with every written wave.
        ngpu: ``>= 1`` selects ``"cuda"``, otherwise ``"cpu"``; must not
            exceed 1.
        seed: Passed to ``set_all_random_seed``.
        num_workers: Passed to the data iterator.
        log_level: Level passed to ``logging.basicConfig``.
        data_path_and_name_and_type: Passed to the data iterator.
        key_file: Passed to the data iterator.
        train_config, model_file, model_tag, inference_config: Forwarded to
            ``SeparateSpeech.from_pretrained``.
        allow_variable_data_keys: Passed to the data iterator.
        segment_size, hop_size, normalize_segment_scale, show_progressbar,
        ref_channel, normalize_output_wav, enh_s2t_task: Forwarded to
            ``SeparateSpeech.from_pretrained``.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    device = "cuda" if ngpu >= 1 else "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build separate_speech
    separate_speech_kwargs = dict(
        train_config=train_config,
        model_file=model_file,
        inference_config=inference_config,
        segment_size=segment_size,
        hop_size=hop_size,
        normalize_segment_scale=normalize_segment_scale,
        show_progressbar=show_progressbar,
        ref_channel=ref_channel,
        normalize_output_wav=normalize_output_wav,
        device=device,
        dtype=dtype,
        enh_s2t_task=enh_s2t_task,
    )
    separate_speech = SeparateSpeech.from_pretrained(
        model_tag=model_tag,
        **separate_speech_kwargs,
    )

    # 3. Build data-iterator
    loader = EnhancementTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=EnhancementTask.build_preprocess_fn(
            separate_speech.enh_train_args, False
        ),
        collate_fn=EnhancementTask.build_collate_fn(
            separate_speech.enh_train_args, False
        ),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Start for-loop
    output_dir = Path(output_dir).expanduser().resolve()
    writers = []
    # try/finally guarantees that already-opened writers are closed even if
    # building a later writer or processing a batch raises.
    try:
        for i in range(separate_speech.num_spk):
            writers.append(
                SoundScpWriter(
                    f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp"
                )
            )

        for i, (keys, batch) in enumerate(loader):
            logging.info(f"[{i}] Enhancing {keys}")
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # The "*_lengths" entries are dropped before calling the model.
            batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}

            waves = separate_speech(**batch)
            for spk, w in enumerate(waves):
                for b in range(batch_size):
                    writers[spk][keys[b]] = fs, w[b]
    finally:
        for writer in writers:
            writer.close()
def get_parser():
    """Return the argument parser provided by ``EnhancementTask``."""
    return EnhancementTask.get_parser()
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    fs: int,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    enh_train_config: str,
    enh_model_file: str,
    allow_variable_data_keys: bool,
    normalize_output_wav: bool,
):
    """Run the enhancement frontend over a dataset and write the waves.

    The model is built with ``EnhancementTask.build_model_from_file``, each
    batch is forwarded through ``enh_model.enh_model.forward_rawwav`` under
    ``torch.no_grad()``, and every returned wave is written with one
    ``SoundScpWriter`` per speaker (``<output_dir>/wavs/<n>`` and
    ``<output_dir>/spk<n>.scp``).

    Args:
        output_dir: Directory prefix used for the writer paths.
        batch_size: Must not exceed 1.
        dtype: Passed to the data iterator.
        fs: Value stored together with every written wave.
        ngpu: ``>= 1`` selects ``"cuda"``, otherwise ``"cpu"``; must not
            exceed 1.
        seed: Passed to ``set_all_random_seed``.
        num_workers: Passed to the data iterator.
        log_level: Level passed to ``logging.basicConfig``.
        data_path_and_name_and_type: Passed to the data iterator.
        key_file: Passed to the data iterator.
        enh_train_config: Passed to ``build_model_from_file``.
        enh_model_file: Passed to ``build_model_from_file``.
        allow_variable_data_keys: Passed to the data iterator.
        normalize_output_wav: If true, each wave is divided by its maximum
            absolute value along dim 1 and scaled by 0.9 before writing.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    device = "cuda" if ngpu >= 1 else "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build Enh model
    enh_model, enh_train_args = EnhancementTask.build_model_from_file(
        enh_train_config, enh_model_file, device
    )
    enh_model.eval()

    num_spk = enh_model.num_spk

    # 3. Build data-iterator
    loader = EnhancementTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=EnhancementTask.build_preprocess_fn(enh_train_args, False),
        collate_fn=EnhancementTask.build_collate_fn(enh_train_args),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    writers = []
    # try/finally guarantees that already-opened writers are closed even if
    # building a later writer or processing a batch raises.
    try:
        for spk in range(num_spk):
            writers.append(
                SoundScpWriter(
                    f"{output_dir}/wavs/{spk + 1}", f"{output_dir}/spk{spk + 1}.scp"
                )
            )

        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            with torch.no_grad():
                # a. To device
                batch = to_device(batch, device)
                # b. Forward Enhancement Frontend
                waves, _, _ = enh_model.enh_model.forward_rawwav(
                    batch["speech_mix"], batch["speech_mix_lengths"]
                )
                assert len(waves[0]) == batch_size, len(waves[0])

            # FIXME(Chenda): will be incorrect when
            #  batch size is not 1 or multi-channel case
            if normalize_output_wav:
                waves = [
                    (w / abs(w).max(dim=1, keepdim=True)[0] * 0.9).T.cpu().numpy()
                    for w in waves
                ]  # list[(sample,batch)]
            else:
                waves = [w.T.cpu().numpy() for w in waves]

            for spk, w in enumerate(waves):
                writers[spk][keys[0]] = fs, w
    finally:
        for writer in writers:
            writer.close()
def test_add_arguments_help():
    """Parsing ``--help`` with the task parser must raise ``SystemExit``."""
    arg_parser = EnhancementTask.get_parser()
    help_argv = ["--help"]
    with pytest.raises(SystemExit):
        arg_parser.parse_args(help_argv)
def test_add_arguments():
    """Smoke test: building the task's argument parser must not raise."""
    EnhancementTask.get_parser()
def test_print_config_and_load_it(tmp_path):
    """A config written by ``print_config`` must be loadable via ``--config``."""
    yaml_path = tmp_path / "config.yaml"
    with yaml_path.open("w") as stream:
        EnhancementTask.print_config(stream)

    load_argv = ["--config", str(yaml_path)]
    EnhancementTask.get_parser().parse_args(load_argv)
def test_main_with_no_args():
    """Calling the task entry point with an empty command must raise ``SystemExit``."""
    empty_argv = []
    with pytest.raises(SystemExit):
        EnhancementTask.main(cmd=empty_argv)
def test_main_print_config():
    """``--print_config`` must terminate the entry point with ``SystemExit``."""
    print_config_argv = ["--print_config"]
    with pytest.raises(SystemExit):
        EnhancementTask.main(cmd=print_config_argv)
def test_main_help():
    """``--help`` must terminate the entry point with ``SystemExit``."""
    help_argv = ["--help"]
    with pytest.raises(SystemExit):
        EnhancementTask.main(cmd=help_argv)