def test_freq_xfmr_rel(num_spks):
    """Shape check for the "sse@freq_xfmr_rel" network.

    The model is built on a "spectrogram-log-cmvn" transform, fed a random
    (4, 64000) tensor, and both ``forward`` and ``infer`` are required to
    return tensors with the same trailing length as the input.

    Args:
        num_spks: number of speakers handed to the network. When it is
            larger than one, ``forward`` must return ``num_spks`` items and
            only the first item of each output is shape-checked.
    """
    batch_size, num_samples = 4, 64000
    multi_speaker = num_spks > 1

    nnet_cls = aps_sse_nnet("sse@freq_xfmr_rel")
    enh_transform = EnhTransform(feats="spectrogram-log-cmvn",
                                 frame_len=512,
                                 frame_hop=256)
    model = nnet_cls(input_size=257,
                     enh_transform=enh_transform,
                     num_spks=num_spks,
                     num_bins=257,
                     att_dim=256,
                     nhead=4,
                     radius=256,
                     feedforward_dim=512,
                     att_dropout=0.1,
                     proj_dropout=0.1,
                     post_norm=True,
                     num_layers=3,
                     non_linear="sigmoid",
                     training_mode="time")

    mixture = th.rand(batch_size, num_samples)

    # Batched forward pass.
    batch_out = model(mixture)
    if multi_speaker:
        assert len(batch_out) == num_spks
        batch_out = batch_out[0]
    assert batch_out.shape == th.Size([batch_size, num_samples])

    # Inference on a single row of the batch.
    single_out = model.infer(mixture[1])
    if multi_speaker:
        single_out = single_out[0]
    assert single_out.shape == th.Size([num_samples])
def test_dense_unet(num_spks, non_linear):
    """Shape check for the "sse@dense_unet" network.

    Feeds a random (4, 64000) tensor through ``forward`` and one row of it
    through ``infer`` and verifies that the output length matches the input.

    Args:
        num_spks: number of speakers handed to the network; when larger than
            one, only the first returned item is shape-checked.
        non_linear: forwarded unchanged as the network's ``non_linear``
            argument.
    """
    batch_size, num_samples = 4, 64000

    nnet_cls = aps_sse_nnet("sse@dense_unet")
    enh_transform = EnhTransform(feats="spectrogram-log-cmvn",
                                 frame_len=512,
                                 frame_hop=256)
    # NOTE(review): P lists nine "0,1" pairs while K and S list eight
    # entries each -- kept verbatim, confirm this is intentional.
    unet_conf = {
        "K": "3,3;3,3;3,3;3,3;3,3;3,3;3,3;3,3",
        "S": "1,1;2,1;2,1;2,1;2,1;2,1;2,1;2,1",
        "P": "0,1;0,1;0,1;0,1;0,1;0,1;0,1;0,1;0,1",
        "O": "0,0,0,0,0,0,0,0",
        "enc_channel": "16,32,32,32,32,64,128,384",
        "dec_channel": "32,16,32,32,32,32,64,128",
        "conv_dropout": 0.3,
        "num_spks": num_spks,
        "rnn_hidden": 512,
        "rnn_layers": 2,
        "rnn_resize": 384,
        "rnn_bidir": False,
        "rnn_dropout": 0.2,
        "num_dense_blocks": 5,
        "enh_transform": enh_transform,
        "non_linear": non_linear,
        "inp_cplx": True,
        "out_cplx": True,
        "training_mode": "time",
    }
    model = nnet_cls(**unet_conf)

    mixture = th.rand(batch_size, num_samples)

    batch_out = model(mixture)
    batch_out = batch_out[0] if num_spks > 1 else batch_out
    assert batch_out.shape == th.Size([batch_size, num_samples])

    single_out = model.infer(mixture[1])
    single_out = single_out[0] if num_spks > 1 else single_out
    assert single_out.shape == th.Size([num_samples])
def toy_rnn(mode, num_spks):
    """Create a small "sse@base_rnn" network for use in tests.

    Args:
        mode: value forwarded as the network's ``training_mode``.
        num_spks: value forwarded as the network's ``num_spks``.

    Returns:
        The instantiated network, configured with a "spectrogram-log-cmvn"
        transform (frame_len=512, frame_hop=256), 257 bins, two layers and a
        hidden size of 256.
    """
    enh_transform = EnhTransform(feats="spectrogram-log-cmvn",
                                 frame_len=512,
                                 frame_hop=256)
    rnn_conf = {
        "enh_transform": enh_transform,
        "num_bins": 257,
        "input_size": 257,
        "num_layers": 2,
        "num_spks": num_spks,
        "hidden": 256,
        "training_mode": mode,
    }
    return aps_sse_nnet("sse@base_rnn")(**rnn_conf)
def run(args):
    """Train a speech separation/enhancement network on a single process.

    Seeds the RNG, loads the YAML configuration, builds the train/valid data
    loaders, the network (with an optional enhancement transform), the task
    and the trainer, writes the merged configuration to
    ``<checkpoint>/train.yaml`` and finally starts training.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``seed``, ``conf``, ``batch_size``, ``num_workers``, ``trainer``,
            ``device_id``, ``checkpoint``, ``resume``, ``init``,
            ``save_interval``, ``prog_interval``, ``tensorboard``, ``epochs``
            and ``eval_interval``.
    """
    # Random seed
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf = load_ss_conf(args.conf)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    # Loader options: command-line defaults first, entries of the "loader"
    # section override them.
    data_conf = conf["data_conf"]
    load_conf = {
        "fmt": data_conf["fmt"],
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        **data_conf["loader"],
    }
    trn_loader = aps_dataloader(train=True, **load_conf, **data_conf["train"])
    dev_loader = aps_dataloader(train=False, **load_conf, **data_conf["valid"])

    # Network, with or without an enhancement transform
    sse_cls = aps_sse_nnet(conf["nnet"])
    if "enh_transform" not in conf:
        nnet = sse_cls(**conf["nnet_conf"])
    else:
        enh_transform = aps_transform("enh")(**conf["enh_transform"])
        nnet = sse_cls(enh_transform=enh_transform, **conf["nnet_conf"])

    task = aps_task(conf["task"], nnet, **conf["task_conf"])

    trainer_cls = aps_trainer(args.trainer, distributed=False)
    trainer = trainer_cls(task,
                          device_ids=args.device_id,
                          checkpoint=args.checkpoint,
                          resume=args.resume,
                          init=args.init,
                          save_interval=args.save_interval,
                          prog_interval=args.prog_interval,
                          tensorboard=args.tensorboard,
                          **conf["trainer_conf"])

    # Dump the configuration (including the command-line arguments) next to
    # the checkpoints.
    conf["cmd_args"] = vars(args)
    with open(f"{args.checkpoint}/train.yaml", "w") as conf_fd:
        yaml.dump(conf, conf_fd)

    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
def test_enh_ml(num_channels):
    """Run a few epochs of the "sse@enh_ml" task on random data.

    An "sse@rnn_enh_ml" network is built on a "spectrogram-log-cmvn-ipd"
    transform (IPD pairs "0,1;0,2"), wrapped into the task and driven through
    ``run_epochs`` for three epochs.

    Args:
        num_channels: size of the second dimension of the random "mix"
            tensor, whose full shape is (4, num_channels, 64000).
    """
    num_epochs = 3

    nnet_cls = aps_sse_nnet("sse@rnn_enh_ml")
    enh_transform = EnhTransform(feats="spectrogram-log-cmvn-ipd",
                                 frame_len=512,
                                 frame_hop=256,
                                 ipd_index="0,1;0,2")
    model = nnet_cls(enh_transform=enh_transform,
                     num_bins=257,
                     input_size=257 * 3,
                     input_proj=512,
                     num_layers=2,
                     hidden=512)
    task = aps_task("sse@enh_ml", model)

    batch = {"mix": th.rand(4, num_channels, 64000)}
    run_epochs(task, batch, num_epochs)
def test_crn():
    """Shape check for the "sse@crn" network in masking / "freq" mode.

    With frame_len=320, frame_hop=160 and no rounding to a power of two,
    ``forward`` on a random (4, 64000) tensor must return a (4, 161, 399)
    tensor and ``infer`` on one row must return a (64000,) tensor.
    """
    batch_size, num_samples = 4, 64000
    num_bins, num_frames = 161, 399

    nnet_cls = aps_sse_nnet("sse@crn")
    enh_transform = EnhTransform(feats="spectrogram-log-cmvn",
                                 frame_len=320,
                                 frame_hop=160,
                                 round_pow_of_two=False)
    model = nnet_cls(num_bins,
                     enh_transform=enh_transform,
                     mode="masking",
                     training_mode="freq")

    mixture = th.rand(batch_size, num_samples)

    batch_out = model(mixture)
    assert batch_out.shape == th.Size([batch_size, num_bins, num_frames])

    single_out = model.infer(mixture[1])
    assert single_out.shape == th.Size([num_samples])
def test_phasen():
    """Shape check for the "sse@phasen" network.

    ``forward`` on a random (4, 64000) tensor must return two tensors of
    shape (4, 257, 249); ``infer`` on one row must return a (64000,) tensor.
    """
    batch_size, num_samples = 4, 64000
    tf_shape = th.Size([batch_size, 257, 249])

    nnet_cls = aps_sse_nnet("sse@phasen")
    enh_transform = EnhTransform(feats="", frame_len=512, frame_hop=256)
    model = nnet_cls(12,
                     4,
                     enh_transform=enh_transform,
                     num_tsbs=1,
                     num_bins=257,
                     channel_r=5,
                     conv1d_kernel=9,
                     lstm_hidden=256,
                     linear_size=512)

    mixture = th.rand(batch_size, num_samples)

    # forward() yields a pair of outputs; both must share the same shape.
    first, second = model(mixture)
    assert first.shape == tf_shape
    assert second.shape == tf_shape

    single_out = model.infer(mixture[1])
    assert single_out.shape == th.Size([num_samples])
def test_dprnn():
    """Shape check for the single-speaker "sse@time_dprnn" network.

    ``forward`` on a random (4, 64000) tensor must keep the input shape and
    ``infer`` on one row must return a (64000,) tensor.
    """
    batch_size, num_samples = 4, 64000

    dprnn_conf = {
        "num_spks": 1,
        "input_norm": "cLN",
        "block_type": "dp",
        "conv_kernels": 16,
        "conv_filters": 64,
        "proj_filters": 64,
        "chunk_len": 100,
        "num_layers": 2,
        "rnn_hidden": 64,
        "rnn_bi_inter": True,
        "non_linear": "relu",
    }
    model = aps_sse_nnet("sse@time_dprnn")(**dprnn_conf)

    mixture = th.rand(batch_size, num_samples)

    batch_out = model(mixture)
    assert batch_out.shape == th.Size([batch_size, num_samples])

    single_out = model.infer(mixture[1])
    assert single_out.shape == th.Size([num_samples])
def test_rnn_enh_ml(num_bins):
    """Shape and NaN check for the "sse@rnn_enh_ml" network.

    The network is built on a "spectrogram-log-cmvn-ipd" transform with IPD
    pairs "0,1;0,2;0,3" and fed a random (2, 5, 64000) tensor.

    Args:
        num_bins: number of bins given to the network; the input size is set
            to ``num_bins * 4``.
    """
    batch_size, num_channels, num_samples = 2, 5, 64000
    num_frames = 249

    nnet_cls = aps_sse_nnet("sse@rnn_enh_ml")
    enh_transform = EnhTransform(feats="spectrogram-log-cmvn-ipd",
                                 frame_len=512,
                                 frame_hop=256,
                                 ipd_index="0,1;0,2;0,3")
    model = nnet_cls(enh_transform=enh_transform,
                     num_bins=num_bins,
                     input_size=num_bins * 4,
                     input_proj=512,
                     num_layers=2,
                     hidden=512)

    mixture = th.rand(batch_size, num_channels, num_samples)

    # forward() returns two tensors; the first one exposes .real / .imag
    # parts, neither of which may contain NaNs.
    cplx_out, real_out = model(mixture)
    assert cplx_out.shape == th.Size(
        [batch_size, num_channels, num_bins, num_frames])
    nan_count = th.isnan(cplx_out.real).sum() + th.isnan(cplx_out.imag).sum()
    assert nan_count == 0
    assert real_out.shape == th.Size([batch_size, num_frames, num_bins])

    single_out = model.infer(mixture[0])
    assert single_out.shape == th.Size([num_frames, num_bins])
def run(args):
    """Build the network and task from the configuration and start training.

    Args:
        args: parsed command-line namespace; ``seed`` and ``conf`` are read
            here and the whole namespace is forwarded to ``train_worker``.
    """
    # Random seed
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf = load_ss_conf(args.conf)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    sse_cls = aps_sse_nnet(conf["nnet"])
    # The enhancement transform is optional: it is only created and handed to
    # the network when the configuration has an "enh_transform" section.
    if "enh_transform" not in conf:
        nnet = sse_cls(**conf["nnet_conf"])
    else:
        enh_transform = aps_transform("enh")(**conf["enh_transform"])
        nnet = sse_cls(enh_transform=enh_transform, **conf["nnet_conf"])

    task = aps_task(conf["task"], nnet, **conf["task_conf"])
    train_worker(task, conf, args)
def run(args):
    """Build the network from the configuration and hand it to the trainer.

    Args:
        args: parsed command-line namespace; ``seed``, ``conf``,
            ``distributed``, ``eval_interval`` and ``trainer`` are read here
            and the whole namespace is forwarded to ``start_trainer``.

    Raises:
        RuntimeError: if ``args.distributed`` is not "none" while
            ``args.eval_interval`` is not larger than 0.
    """
    # NOTE(review): set_seed is invoked with the same argument both before
    # and after load_ss_conf; both calls are kept as they were -- confirm
    # whether the second one is actually required.
    set_seed(args.seed)
    conf = load_ss_conf(args.conf)
    set_seed(args.seed)

    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    sse_cls = aps_sse_nnet(conf["nnet"])
    # The enhancement transform is optional: it is only created and handed to
    # the network when the configuration has an "enh_transform" section.
    if "enh_transform" not in conf:
        nnet = sse_cls(**conf["nnet_conf"])
    else:
        enh_transform = aps_transform("enh")(**conf["enh_transform"])
        nnet = sse_cls(enh_transform=enh_transform, **conf["nnet_conf"])

    if args.distributed != "none" and args.eval_interval <= 0:
        raise RuntimeError("For distributed training of SE/SS model, "
                           "--eval-interval must be larger than 0")

    start_trainer(args.trainer, conf, nnet, args, reduction_tag="none")
def test_base_rnn(num_spks, nonlinear):
    """Shape check for the "sse@base_rnn" network.

    ``forward`` on a random (2, 64000) tensor must return (2, 257, 249)
    output and ``infer`` on one row must return a (64000,) tensor.

    Args:
        num_spks: number of speakers handed to the network; when larger than
            one, only the first returned item is shape-checked.
        nonlinear: forwarded as the network's ``output_nonlinear`` argument.
    """
    batch_size, num_samples = 2, 64000
    multi_speaker = num_spks > 1

    nnet_cls = aps_sse_nnet("sse@base_rnn")
    enh_transform = EnhTransform(feats="spectrogram-log-cmvn",
                                 frame_len=512,
                                 frame_hop=256)
    model = nnet_cls(enh_transform=enh_transform,
                     num_bins=257,
                     input_size=257,
                     input_proj=512,
                     num_layers=2,
                     hidden=512,
                     num_spks=num_spks,
                     output_nonlinear=nonlinear)

    mixture = th.rand(batch_size, num_samples)

    batch_out = model(mixture)
    if multi_speaker:
        batch_out = batch_out[0]
    assert batch_out.shape == th.Size([batch_size, 257, 249])

    single_out = model.infer(mixture[0])
    if multi_speaker:
        single_out = single_out[0]
    assert single_out.shape == th.Size([num_samples])
def _load(self, cpt_dir: str, cpt_tag: str = "best", task: str = "asr") -> Tuple[int, nn.Module, Dict]: if task not in ["asr", "sse"]: raise ValueError(f"Unknown task name: {task}") cpt_dir = pathlib.Path(cpt_dir) # load checkpoint cpt = th.load(cpt_dir / f"{cpt_tag}.pt.tar", map_location="cpu") with open(cpt_dir / "train.yaml", "r") as f: conf = yaml.full_load(f) if task == "asr": net_cls = aps_asr_nnet(conf["nnet"]) else: net_cls = aps_sse_nnet(conf["nnet"]) asr_transform = None enh_transform = None self.accept_raw = False if "asr_transform" in conf: asr_transform = aps_transform("asr")(**conf["asr_transform"]) # if no STFT layer self.accept_raw = asr_transform.spectra_index != -1 if "enh_transform" in conf: enh_transform = aps_transform("enh")(**conf["enh_transform"]) self.accept_raw = True if enh_transform and asr_transform: nnet = net_cls(enh_transform=enh_transform, asr_transform=asr_transform, **conf["nnet_conf"]) elif asr_transform: nnet = net_cls(asr_transform=asr_transform, **conf["nnet_conf"]) elif enh_transform: nnet = net_cls(enh_transform=enh_transform, **conf["nnet_conf"]) else: nnet = net_cls(**conf["nnet_conf"]) nnet.load_state_dict(cpt["model_state"]) return cpt["epoch"], nnet, conf
def test_dcunet(num_branch, cplx):
    """Shape check for the "sse@dcunet" network.

    ``forward`` on a random (4, 64000) tensor must keep the input shape and
    ``infer`` on one row must return a (64000,) tensor.

    Args:
        num_branch: forwarded as the network's ``num_branch``; when larger
            than one, only the first returned item is shape-checked.
        cplx: forwarded as the network's ``cplx`` argument.
    """
    batch_size, num_samples = 4, 64000

    nnet_cls = aps_sse_nnet("sse@dcunet")
    enh_transform = EnhTransform(feats="", frame_len=512, frame_hop=256)
    unet_conf = {
        "enh_transform": enh_transform,
        "K": "7,5;7,5;5,3;5,3;3,3;3,3",
        "S": "2,1;2,1;2,1;2,1;2,1;2,1",
        "C": "32,32,64,64,64,128",
        "P": "1,1,1,1,1,0",
        "O": "0,0,1,1,1,0",
        "num_branch": num_branch,
        "cplx": cplx,
        "causal_conv": False,
        "freq_padding": True,
        "connection": "cat",
    }
    model = nnet_cls(**unet_conf)

    mixture = th.rand(batch_size, num_samples)

    batch_out = model(mixture)
    batch_out = batch_out[0] if num_branch > 1 else batch_out
    assert batch_out.shape == th.Size([batch_size, num_samples])

    single_out = model.infer(mixture[1])
    single_out = single_out[0] if num_branch > 1 else single_out
    assert single_out.shape == th.Size([num_samples])
def test_dccrn(num_spks, cplx):
    """Shape check for the "sse@dccrn" network.

    ``forward`` on a random (4, 64000) tensor must keep the input shape and
    ``infer`` on one row must return a (64000,) tensor.

    Args:
        num_spks: number of speakers handed to the network; when larger than
            one, only the first returned item is shape-checked.
        cplx: forwarded as the network's ``cplx`` argument; it also selects
            ``rnn_resize`` (512 when truthy, 256 otherwise).
    """
    batch_size, num_samples = 4, 64000
    if cplx:
        rnn_resize = 512
    else:
        rnn_resize = 256

    nnet_cls = aps_sse_nnet("sse@dccrn")
    enh_transform = EnhTransform(feats="spectrogram",
                                 frame_len=512,
                                 frame_hop=256)
    model = nnet_cls(enh_transform=enh_transform,
                     cplx=cplx,
                     K="3,3;3,3;3,3;3,3;3,3;3,3;3,3",
                     S="2,1;2,1;2,1;2,1;2,1;2,1;2,1",
                     P="1,1,1,1,1,0,0",
                     O="0,0,0,0,0,0,1",
                     C="16,32,64,64,128,128,256",
                     num_spks=num_spks,
                     rnn_resize=rnn_resize,
                     non_linear="sigmoid",
                     connection="cat")

    mixture = th.rand(batch_size, num_samples)

    batch_out = model(mixture)
    batch_out = batch_out[0] if num_spks > 1 else batch_out
    assert batch_out.shape == th.Size([batch_size, num_samples])

    single_out = model.infer(mixture[1])
    single_out = single_out[0] if num_spks > 1 else single_out
    assert single_out.shape == th.Size([num_samples])
def test_tasnet(num_spks, nonlinear):
    """Shape check for the "sse@time_tasnet" network.

    ``forward`` on a random (4, 64000) tensor must keep the input shape and
    ``infer`` on one row must return a (64000,) tensor.

    Args:
        num_spks: number of speakers handed to the network; when larger than
            one, only the first returned item is shape-checked.
        nonlinear: forwarded as the network's ``non_linear`` argument.
    """
    batch_size, num_samples = 4, 64000
    multi_speaker = num_spks > 1

    tasnet_conf = {
        "L": 40,
        "N": 256,
        "X": 8,
        "R": 4,
        "B": 256,
        "H": 512,
        "P": 3,
        "input_norm": "gLN",
        "norm": "BN",
        "num_spks": num_spks,
        "non_linear": nonlinear,
        "block_residual": True,
        "causal": False,
    }
    model = aps_sse_nnet("sse@time_tasnet")(**tasnet_conf)

    mixture = th.rand(batch_size, num_samples)

    batch_out = model(mixture)
    if multi_speaker:
        batch_out = batch_out[0]
    assert batch_out.shape == th.Size([batch_size, num_samples])

    single_out = model.infer(mixture[1])
    if multi_speaker:
        single_out = single_out[0]
    assert single_out.shape == th.Size([num_samples])