def test_demask():
    """Round-trip a default DeMask through serialize()/from_pretrained().

    A freshly built model is serialized, a second model is rebuilt from that
    package, and both must produce matching outputs on the same random input.
    """
    original = DeMask()
    waveform = torch.randn(1, 801)
    package = original.serialize()
    restored = DeMask.from_pretrained(package)
    assert_allclose(original(waveform), restored(waveform))
def test_forward(input_type, output_type, fb_name, data):
    """Smoke-test a small DeMask forward pass for the given configuration.

    The model is switched to eval mode and run without autograd; the test
    only checks that the call completes without raising.
    """
    config = dict(
        input_type=input_type,
        output_type=output_type,
        fb_name=fb_name,
        hidden_dims=(16,),
        kernel_size=8,
        n_filters=8,
        stride=4,
    )
    model = DeMask(**config)
    model = model.eval()
    with torch.no_grad():
        model(data)
def test_sample_rate():
    """A non-default ``sample_rate`` passed to DeMask is exposed unchanged."""
    target_rate = 9704
    model = DeMask(
        hidden_dims=(16,),
        kernel_size=8,
        n_filters=8,
        stride=4,
        sample_rate=target_rate,
    )
    assert model.sample_rate == target_rate
def test_get_model_args():
    """A default DeMask reports exactly the expected constructor arguments."""
    model = DeMask()
    expected = dict(
        activation="relu",
        dropout=0,
        fb_name="STFTFB",
        hidden_dims=(1024,),
        input_type="mag",
        kernel_size=512,
        mask_act="relu",
        n_filters=512,
        norm_type="gLN",
        output_type="mag",
        sample_rate=16000,
        stride=256,
    )
    actual = model.get_model_args()
    assert actual == expected
def test_enhancement_model(small_model_params, test_data):
    """Check that a TorchScript-traced DeMask matches the eager-mode model.

    Args:
        small_model_params: fixture mapping model names to small constructor
            kwargs; the ``"DeMask"`` entry is used here.
        test_data: fixture tensor fed to both the eager and traced models.
    """
    params = small_model_params["DeMask"]
    filter_bank = "free"
    device = get_default_device()
    # Example input in [-1, 1), used only to record the trace.
    inputs = ((torch.rand(1, 201, device=device) - 0.5) * 2,)
    test_data = test_data.to(device)
    # DeMask's filterbank keyword is ``fb_name`` (the name used by
    # test_forward and reported by get_model_args), so pass it under that
    # name for the "free" filterbank to be the one under test.
    # NOTE(review): assumes params does not already contain "fb_name".
    model = DeMask(**params, fb_name=filter_bank).eval().to(device)
    traced = torch.jit.trace(model, inputs)

    # check forward
    with torch.no_grad():
        ref = model(test_data)
        out = traced(test_data)
    assert_allclose(ref, out)