Ejemplo n.º 1
0
def test_freq_xfmr_rel(num_spks):
    """Shape check for the ``sse@freq_xfmr_rel`` separator.

    The batched forward pass must yield ``num_spks`` outputs of shape
    (4, 64000) when ``num_spks > 1`` and a single (4, 64000) tensor
    otherwise; ``infer`` on one utterance must yield a 64000-sample
    waveform (taking the first output when ``num_spks > 1``).
    """
    num_samples = 64000
    model_cls = aps_sse_nnet("sse@freq_xfmr_rel")
    front_end = EnhTransform(feats="spectrogram-log-cmvn",
                             frame_len=512,
                             frame_hop=256)
    xfmr_conf = {
        "input_size": 257,
        "num_bins": 257,
        "att_dim": 256,
        "nhead": 4,
        "radius": 256,
        "feedforward_dim": 512,
        "att_dropout": 0.1,
        "proj_dropout": 0.1,
        "post_norm": True,
        "num_layers": 3,
        "non_linear": "sigmoid",
        "training_mode": "time",
    }
    model = model_cls(enh_transform=front_end, num_spks=num_spks, **xfmr_conf)

    batch = th.rand(4, num_samples)
    separated = model(batch)
    if num_spks > 1:
        assert len(separated) == num_spks
        separated = separated[0]
    assert separated.shape == th.Size([4, num_samples])

    single = model.infer(batch[1])
    first = single[0] if num_spks > 1 else single
    assert first.shape == th.Size([num_samples])
Ejemplo n.º 2
0
def test_dense_unet(num_spks, non_linear):
    """Shape check for the ``sse@dense_unet`` complex-in/complex-out model.

    Both the batched forward pass (4 x 64000 samples) and single-utterance
    ``infer`` must return waveforms of the input length; when
    ``num_spks > 1`` only the first output is checked.
    """
    num_samples = 64000
    model_cls = aps_sse_nnet("sse@dense_unet")
    front_end = EnhTransform(feats="spectrogram-log-cmvn",
                             frame_len=512,
                             frame_hop=256)
    unet_conf = dict(
        K="3,3;3,3;3,3;3,3;3,3;3,3;3,3;3,3",
        S="1,1;2,1;2,1;2,1;2,1;2,1;2,1;2,1",
        # NOTE(review): P lists 9 pairs while K/S/O and the channel specs
        # list 8 layers -- confirm the extra entry is intended or ignored.
        P="0,1;0,1;0,1;0,1;0,1;0,1;0,1;0,1;0,1",
        O="0,0,0,0,0,0,0,0",
        enc_channel="16,32,32,32,32,64,128,384",
        dec_channel="32,16,32,32,32,32,64,128",
        conv_dropout=0.3,
        num_spks=num_spks,
        rnn_hidden=512,
        rnn_layers=2,
        rnn_resize=384,
        rnn_bidir=False,
        rnn_dropout=0.2,
        num_dense_blocks=5,
        enh_transform=front_end,
        non_linear=non_linear,
        inp_cplx=True,
        out_cplx=True,
        training_mode="time",
    )
    dense_unet = model_cls(**unet_conf)

    batch = th.rand(4, num_samples)
    batch_out = dense_unet(batch)
    batch_out = batch_out[0] if num_spks > 1 else batch_out
    assert batch_out.shape == th.Size([4, num_samples])

    single_out = dense_unet.infer(batch[1])
    single_out = single_out[0] if num_spks > 1 else single_out
    assert single_out.shape == th.Size([num_samples])
Ejemplo n.º 3
0
def toy_rnn(mode, num_spks):
    """Build a small two-layer ``sse@base_rnn`` model for tests.

    Args:
        mode: forwarded to the model as ``training_mode``.
        num_spks: forwarded to the model as ``num_spks``.

    Returns:
        The constructed model, using a ``spectrogram-log-cmvn`` front end
        (frame_len=512, frame_hop=256), 257 bins and 256 hidden units.
    """
    front_end = EnhTransform(feats="spectrogram-log-cmvn",
                             frame_len=512,
                             frame_hop=256)
    rnn_cls = aps_sse_nnet("sse@base_rnn")
    rnn_conf = {
        "num_bins": 257,
        "input_size": 257,
        "num_layers": 2,
        "hidden": 256,
    }
    return rnn_cls(enh_transform=front_end,
                   num_spks=num_spks,
                   training_mode=mode,
                   **rnn_conf)
Ejemplo n.º 4
0
Archivo: train_ss.py Proyecto: xmpx/aps
def run(args):
    """Train an SE/SS model on a single node.

    Loads the YAML config at ``args.conf``, builds the train/valid data
    loaders, the network (with an ``enh_transform`` front end when the
    config has one), the task and a non-distributed trainer, dumps the
    config (plus the command-line arguments under ``cmd_args``) to
    ``<args.checkpoint>/train.yaml`` and then runs training.
    """
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf = load_ss_conf(args.conf)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    # loader options shared by both splits; the "loader" section may
    # override any of the command-line values
    data_conf = conf["data_conf"]
    loader_opts = dict(fmt=data_conf["fmt"],
                       batch_size=args.batch_size,
                       num_workers=args.num_workers)
    loader_opts.update(data_conf["loader"])
    trn_loader = aps_dataloader(train=True, **loader_opts, **data_conf["train"])
    dev_loader = aps_dataloader(train=False, **loader_opts, **data_conf["valid"])

    sse_cls = aps_sse_nnet(conf["nnet"])
    front_end = {}
    if "enh_transform" in conf:
        transform = aps_transform("enh")(**conf["enh_transform"])
        front_end["enh_transform"] = transform
    nnet = sse_cls(**front_end, **conf["nnet_conf"])

    task = aps_task(conf["task"], nnet, **conf["task_conf"])

    trainer_cls = aps_trainer(args.trainer, distributed=False)
    trainer = trainer_cls(task,
                          device_ids=args.device_id,
                          checkpoint=args.checkpoint,
                          resume=args.resume,
                          init=args.init,
                          save_interval=args.save_interval,
                          prog_interval=args.prog_interval,
                          tensorboard=args.tensorboard,
                          **conf["trainer_conf"])

    # keep a copy of the exact configuration next to the checkpoints
    conf["cmd_args"] = vars(args)
    with open(f"{args.checkpoint}/train.yaml", "w") as yaml_file:
        yaml.dump(conf, yaml_file)

    trainer.run(trn_loader,
                dev_loader,
                num_epochs=args.epochs,
                eval_interval=args.eval_interval)
Ejemplo n.º 5
0
def test_enh_ml(num_channels):
    """Smoke-train ``sse@rnn_enh_ml`` under the ``sse@enh_ml`` task.

    Uses a random batch of 4 mixtures with ``num_channels`` channels of
    64000 samples each and runs 3 epochs via ``run_epochs``.
    """
    model_cls = aps_sse_nnet("sse@rnn_enh_ml")
    front_end = EnhTransform(feats="spectrogram-log-cmvn-ipd",
                             frame_len=512,
                             frame_hop=256,
                             ipd_index="0,1;0,2")
    rnn_conf = {
        "num_bins": 257,
        "input_size": 257 * 3,
        "input_proj": 512,
        "num_layers": 2,
        "hidden": 512,
    }
    rnn_ml = model_cls(enh_transform=front_end, **rnn_conf)
    ml_task = aps_task("sse@enh_ml", rnn_ml)
    batch = {"mix": th.rand(4, num_channels, 64000)}
    run_epochs(ml_task, batch, 3)
Ejemplo n.º 6
0
def test_crn():
    """Shape check for ``sse@crn`` in masking mode with frequency training.

    The forward pass on 4 x 64000 samples must return a (4, 161, 399)
    tensor; ``infer`` on one utterance must return 64000 samples.
    """
    num_samples = 64000
    model_cls = aps_sse_nnet("sse@crn")
    front_end = EnhTransform(feats="spectrogram-log-cmvn",
                             frame_len=320,
                             frame_hop=160,
                             round_pow_of_two=False)
    crn = model_cls(161,
                    enh_transform=front_end,
                    mode="masking",
                    training_mode="freq")
    batch = th.rand(4, num_samples)
    assert crn(batch).shape == th.Size([4, 161, 399])
    assert crn.infer(batch[1]).shape == th.Size([num_samples])
Ejemplo n.º 7
0
def test_phasen():
    """Shape check for ``sse@phasen``.

    The forward pass on 4 x 64000 samples returns two (4, 257, 249)
    tensors; ``infer`` on one utterance returns 64000 samples.
    """
    num_samples = 64000
    model_cls = aps_sse_nnet("sse@phasen")
    front_end = EnhTransform(feats="", frame_len=512, frame_hop=256)
    phasen = model_cls(12,
                       4,
                       enh_transform=front_end,
                       num_tsbs=1,
                       num_bins=257,
                       channel_r=5,
                       conv1d_kernel=9,
                       lstm_hidden=256,
                       linear_size=512)
    batch = th.rand(4, num_samples)
    expected = th.Size([4, 257, 249])
    first_out, second_out = phasen(batch)
    assert first_out.shape == expected
    assert second_out.shape == expected
    assert phasen.infer(batch[1]).shape == th.Size([num_samples])
Ejemplo n.º 8
0
def test_dprnn():
    """Shape check for the time-domain ``sse@time_dprnn`` (single speaker).

    Both the batched forward pass and single-utterance ``infer`` must
    preserve the waveform length (64000 samples).
    """
    num_samples = 64000
    model_cls = aps_sse_nnet("sse@time_dprnn")
    dprnn_conf = dict(num_spks=1,
                      input_norm="cLN",
                      block_type="dp",
                      conv_kernels=16,
                      conv_filters=64,
                      proj_filters=64,
                      chunk_len=100,
                      num_layers=2,
                      rnn_hidden=64,
                      rnn_bi_inter=True,
                      non_linear="relu")
    dprnn = model_cls(**dprnn_conf)
    batch = th.rand(4, num_samples)
    assert dprnn(batch).shape == th.Size([4, num_samples])
    assert dprnn.infer(batch[1]).shape == th.Size([num_samples])
Ejemplo n.º 9
0
def test_rnn_enh_ml(num_bins):
    """Shape and NaN check for ``sse@rnn_enh_ml`` on 5-channel input.

    The forward pass on a (2, 5, 64000) batch returns a complex
    (2, 5, num_bins, 249) tensor free of NaNs and a (2, 249, num_bins)
    tensor; ``infer`` on one example returns (249, num_bins).
    """
    num_frames = 249
    model_cls = aps_sse_nnet("sse@rnn_enh_ml")
    front_end = EnhTransform(feats="spectrogram-log-cmvn-ipd",
                             frame_len=512,
                             frame_hop=256,
                             ipd_index="0,1;0,2;0,3")
    rnn_enh_ml = model_cls(enh_transform=front_end,
                           num_bins=num_bins,
                           input_size=num_bins * 4,
                           input_proj=512,
                           num_layers=2,
                           hidden=512)
    batch = th.rand(2, 5, 64000)
    cplx_out, aux_out = rnn_enh_ml(batch)
    assert cplx_out.shape == th.Size([2, 5, num_bins, num_frames])
    nan_count = th.isnan(cplx_out.real).sum() + th.isnan(cplx_out.imag).sum()
    assert nan_count == 0
    assert aux_out.shape == th.Size([2, num_frames, num_bins])
    single_out = rnn_enh_ml.infer(batch[0])
    assert single_out.shape == th.Size([num_frames, num_bins])
Ejemplo n.º 10
0
def run(args):
    """Build the SE/SS task described by ``args.conf`` and train it.

    Seeds the RNGs (reporting the seed when one is set), prints the
    command-line and YAML arguments, builds the network (with an
    ``enh_transform`` front end when the config has one) wrapped in the
    configured task, and hands everything to ``train_worker``.
    """
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")

    conf = load_ss_conf(args.conf)
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    sse_cls = aps_sse_nnet(conf["nnet"])
    # the enhancement front end is optional
    front_end = {}
    if "enh_transform" in conf:
        transform = aps_transform("enh")(**conf["enh_transform"])
        front_end["enh_transform"] = transform
    nnet = sse_cls(**front_end, **conf["nnet_conf"])

    train_worker(aps_task(conf["task"], nnet, **conf["task_conf"]), conf, args)
Ejemplo n.º 11
0
def run(args):
    """Build the SE/SS network described by ``args.conf`` and start training.

    Seeds the RNGs once (reporting the seed when one is set), prints the
    command-line and YAML arguments, builds the network (with an
    ``enh_transform`` front end when the config has one) and hands it to
    ``start_trainer``.

    Raises:
        RuntimeError: distributed training is requested
            (``args.distributed != "none"``) with a non-positive
            ``--eval-interval``.
    """
    # Validate the command line before doing any expensive work.
    is_distributed = args.distributed != "none"
    if is_distributed and args.eval_interval <= 0:
        raise RuntimeError("For distributed training of SE/SS model, "
                           "--eval-interval must be larger than 0")

    conf = load_ss_conf(args.conf)
    # Seed exactly once, right before the network is built. The original
    # code called set_seed twice in a row, which was redundant.
    seed = set_seed(args.seed)
    if seed is not None:
        print(f"Set random seed as {seed}")
    print(f"Arguments in args:\n{pprint.pformat(vars(args))}", flush=True)
    print(f"Arguments in yaml:\n{pprint.pformat(conf)}", flush=True)

    sse_cls = aps_sse_nnet(conf["nnet"])
    # with or without enh_transform
    if "enh_transform" in conf:
        enh_transform = aps_transform("enh")(**conf["enh_transform"])
        nnet = sse_cls(enh_transform=enh_transform, **conf["nnet_conf"])
    else:
        nnet = sse_cls(**conf["nnet_conf"])

    start_trainer(args.trainer, conf, nnet, args, reduction_tag="none")
Ejemplo n.º 12
0
def test_base_rnn(num_spks, nonlinear):
    """Shape check for ``sse@base_rnn`` with a configurable output nonlinearity.

    The forward pass on 2 x 64000 samples must return (2, 257, 249) and
    ``infer`` on one utterance 64000 samples; for ``num_spks > 1`` only the
    first output is checked.
    """
    num_samples = 64000
    model_cls = aps_sse_nnet("sse@base_rnn")
    front_end = EnhTransform(feats="spectrogram-log-cmvn",
                             frame_len=512,
                             frame_hop=256)
    rnn_conf = {
        "num_bins": 257,
        "input_size": 257,
        "input_proj": 512,
        "num_layers": 2,
        "hidden": 512,
        "num_spks": num_spks,
        "output_nonlinear": nonlinear,
    }
    base_rnn = model_cls(enh_transform=front_end, **rnn_conf)

    batch = th.rand(2, num_samples)
    batch_out = base_rnn(batch)
    batch_out = batch_out[0] if num_spks > 1 else batch_out
    assert batch_out.shape == th.Size([2, 257, 249])

    single_out = base_rnn.infer(batch[0])
    single_out = single_out[0] if num_spks > 1 else single_out
    assert single_out.shape == th.Size([num_samples])
Ejemplo n.º 13
0
Archivo: eval.py Proyecto: xmpx/aps
    def _load(self,
              cpt_dir: str,
              cpt_tag: str = "best",
              task: str = "asr") -> Tuple[int, nn.Module, Dict]:
        """Rebuild a trained network from a checkpoint directory.

        Reads ``<cpt_dir>/<cpt_tag>.pt.tar`` (mapped to CPU) and
        ``<cpt_dir>/train.yaml``. It rebuilds the network named in the YAML,
        including any ``asr_transform`` / ``enh_transform`` front end, and
        loads the saved ``model_state`` into it. It also sets
        ``self.accept_raw`` to whether the rebuilt network takes raw
        waveform input.

        Args:
            cpt_dir: checkpoint directory.
            cpt_tag: checkpoint file stem (``best`` by default).
            task: ``"asr"`` or ``"sse"``; selects the network registry.

        Returns:
            A tuple of (the ``epoch`` stored in the checkpoint, the restored
            network, the training configuration).

        Raises:
            ValueError: ``task`` is neither ``"asr"`` nor ``"sse"``.
        """
        if task not in ["asr", "sse"]:
            raise ValueError(f"Unknown task name: {task}")
        cpt_dir = pathlib.Path(cpt_dir)
        # NOTE: th.load unpickles the checkpoint. Only use this on trusted
        # checkpoint directories.
        cpt = th.load(cpt_dir / f"{cpt_tag}.pt.tar", map_location="cpu")
        with open(cpt_dir / "train.yaml", "r") as f:
            conf = yaml.full_load(f)
        registry = aps_asr_nnet if task == "asr" else aps_sse_nnet
        net_cls = registry(conf["nnet"])

        asr_transform = None
        enh_transform = None
        self.accept_raw = False
        if "asr_transform" in conf:
            asr_transform = aps_transform("asr")(**conf["asr_transform"])
            # Raw waveform is accepted only when the ASR front end has an STFT
            # layer. Presumably spectra_index == -1 means it has none; confirm.
            self.accept_raw = asr_transform.spectra_index != -1
        if "enh_transform" in conf:
            enh_transform = aps_transform("enh")(**conf["enh_transform"])
            self.accept_raw = True

        # Pass only the transforms that were built. Use explicit `is not None`
        # checks instead of relying on the truthiness of nn.Module objects.
        transforms = {}
        if enh_transform is not None:
            transforms["enh_transform"] = enh_transform
        if asr_transform is not None:
            transforms["asr_transform"] = asr_transform
        nnet = net_cls(**transforms, **conf["nnet_conf"])

        nnet.load_state_dict(cpt["model_state"])
        return cpt["epoch"], nnet, conf
Ejemplo n.º 14
0
def test_dcunet(num_branch, cplx):
    """Shape check for ``sse@dcunet`` (real or complex, 1+ branches).

    Both the batched forward pass and single-utterance ``infer`` must
    return 64000-sample waveforms; with ``num_branch > 1`` only the first
    branch is checked.
    """
    num_samples = 64000
    model_cls = aps_sse_nnet("sse@dcunet")
    front_end = EnhTransform(feats="", frame_len=512, frame_hop=256)
    unet_conf = dict(K="7,5;7,5;5,3;5,3;3,3;3,3",
                     S="2,1;2,1;2,1;2,1;2,1;2,1",
                     C="32,32,64,64,64,128",
                     P="1,1,1,1,1,0",
                     O="0,0,1,1,1,0",
                     num_branch=num_branch,
                     cplx=cplx,
                     causal_conv=False,
                     freq_padding=True,
                     connection="cat")
    dcunet = model_cls(enh_transform=front_end, **unet_conf)

    batch = th.rand(4, num_samples)
    batch_out = dcunet(batch)
    batch_out = batch_out[0] if num_branch > 1 else batch_out
    assert batch_out.shape == th.Size([4, num_samples])

    single_out = dcunet.infer(batch[1])
    single_out = single_out[0] if num_branch > 1 else single_out
    assert single_out.shape == th.Size([num_samples])
Ejemplo n.º 15
0
def test_dccrn(num_spks, cplx):
    """Shape check for ``sse@dccrn`` (real or complex variant).

    Both the batched forward pass and single-utterance ``infer`` must
    return 64000-sample waveforms; for ``num_spks > 1`` only the first
    output is checked.
    """
    num_samples = 64000
    model_cls = aps_sse_nnet("sse@dccrn")
    front_end = EnhTransform(feats="spectrogram", frame_len=512, frame_hop=256)
    # the complex variant is configured with a wider RNN projection
    resize = 512 if cplx else 256
    dccrn = model_cls(enh_transform=front_end,
                      cplx=cplx,
                      K="3,3;3,3;3,3;3,3;3,3;3,3;3,3",
                      S="2,1;2,1;2,1;2,1;2,1;2,1;2,1",
                      P="1,1,1,1,1,0,0",
                      O="0,0,0,0,0,0,1",
                      C="16,32,64,64,128,128,256",
                      num_spks=num_spks,
                      rnn_resize=resize,
                      non_linear="sigmoid",
                      connection="cat")

    batch = th.rand(4, num_samples)
    batch_out = dccrn(batch)
    batch_out = batch_out[0] if num_spks > 1 else batch_out
    assert batch_out.shape == th.Size([4, num_samples])

    single_out = dccrn.infer(batch[1])
    single_out = single_out[0] if num_spks > 1 else single_out
    assert single_out.shape == th.Size([num_samples])
Ejemplo n.º 16
0
def test_tasnet(num_spks, nonlinear):
    """Shape check for the time-domain ``sse@time_tasnet`` separator.

    Both the batched forward pass and single-utterance ``infer`` must
    preserve the waveform length (64000 samples); for ``num_spks > 1``
    only the first output is checked.
    """
    num_samples = 64000
    model_cls = aps_sse_nnet("sse@time_tasnet")
    tasnet_conf = {
        "L": 40,
        "N": 256,
        "X": 8,
        "R": 4,
        "B": 256,
        "H": 512,
        "P": 3,
        "input_norm": "gLN",
        "norm": "BN",
        "num_spks": num_spks,
        "non_linear": nonlinear,
        "block_residual": True,
        "causal": False,
    }
    tasnet = model_cls(**tasnet_conf)

    batch = th.rand(4, num_samples)
    batch_out = tasnet(batch)
    batch_out = batch_out[0] if num_spks > 1 else batch_out
    assert batch_out.shape == th.Size([4, num_samples])

    single_out = tasnet.infer(batch[1])
    single_out = single_out[0] if num_spks > 1 else single_out
    assert single_out.shape == th.Size([num_samples])