Ejemplo n.º 1
0
def test_demask():
    """A DeMask rebuilt from its serialized form must reproduce its outputs."""
    original = DeMask()
    probe = torch.randn(1, 801)

    # Round trip: serialize the freshly built model, then rebuild from it.
    rebuilt = DeMask.from_pretrained(original.serialize())

    expected = original(probe)
    actual = rebuilt(probe)
    assert_allclose(expected, actual)
Ejemplo n.º 2
0
def test_sample_rate():
    """The ``sample_rate`` given to DeMask is exposed as an attribute."""
    rate = 9704
    # Deliberately tiny architecture so construction stays cheap.
    small_config = dict(hidden_dims=(16, ), kernel_size=8, n_filters=8, stride=4)
    demask = DeMask(sample_rate=rate, **small_config)
    assert demask.sample_rate == rate
Ejemplo n.º 3
0
def test_forward(input_type, output_type, fb_name, data):
    """Smoke-test inference of a small DeMask for one parametrized configuration.

    Only checks that the forward pass runs without raising; the output is
    discarded.
    """
    model = DeMask(
        input_type=input_type,
        output_type=output_type,
        fb_name=fb_name,
        hidden_dims=(16, ),
        kernel_size=8,
        n_filters=8,
        stride=4,
    ).eval()
    with torch.no_grad():
        model(data)
Ejemplo n.º 4
0
def test_enhancement_model(small_model_params, test_data):
    """Check that a TorchScript-traced DeMask matches the eager model.

    The model is traced on a random (1, 201) input and then both the eager
    and traced versions are run on ``test_data``; their outputs must agree.

    Args:
        small_model_params: fixture mapping model names to small constructor
            kwargs; the ``"DeMask"`` entry is used here.
        test_data: fixture tensor used to compare eager vs. traced outputs.
    """
    params = small_model_params["DeMask"]
    filter_bank = "free"
    device = get_default_device()
    # The filterbank is selected through the ``fb_name`` keyword (the name used
    # elsewhere in this file and reported by ``get_model_args``). Merging into
    # a dict also avoids a duplicate-keyword TypeError should ``params``
    # already contain ``fb_name``.
    model_kwargs = {**params, "fb_name": filter_bank}
    # Tracing example: uniform noise in [-1, 1).
    inputs = ((torch.rand(1, 201, device=device) - 0.5) * 2, )
    test_data = test_data.to(device)
    model = DeMask(**model_kwargs).eval().to(device)
    traced = torch.jit.trace(model, inputs)

    # check forward
    with torch.no_grad():
        ref = model(test_data)
        out = traced(test_data)
        assert_allclose(ref, out)
Ejemplo n.º 5
0
def test_get_model_args():
    """A default DeMask must report exactly these constructor arguments."""
    default_args = dict(
        activation="relu",
        dropout=0,
        fb_name="STFTFB",
        hidden_dims=(1024, ),
        input_type="mag",
        kernel_size=512,
        mask_act="relu",
        n_filters=512,
        norm_type="gLN",
        output_type="mag",
        sample_rate=16000,
        stride=256,
    )
    reported = DeMask().get_model_args()
    assert reported == default_args