def test_model_forward_fx_torchscript(model_name, batch_size):
    """Symbolically trace each model, script it, and run a single forward pass.

    The model is built with ``create_model`` under ``set_scriptable(True)``,
    converted with ``_create_fx_model``, compiled with ``torch.jit.script`` and
    run once (without autograd) on random input. Every output tensor of the
    scripted graph must keep the requested batch dimension and contain no NaNs.

    Args:
        model_name: Name passed to ``create_model``.
        batch_size: Leading dimension of the random input tensor.
    """
    if not has_fx_feature_extraction:
        pytest.skip(
            "Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required."
        )

    input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
    if max(input_size) > MAX_JIT_SIZE:
        pytest.skip("Fixed input size model > limit.")

    with set_scriptable(True):
        model = create_model(model_name, pretrained=False)
    model.eval()

    model = torch.jit.script(_create_fx_model(model))
    with torch.no_grad():
        # The scripted graph's result is consumed via `.values()`, so it is
        # presumably a mapping of output name -> tensor.
        outputs = tuple(model(torch.randn((batch_size, *input_size))).values())

    assert outputs, 'Model produced no outputs'
    for output in outputs:
        # Validate each output separately. Concatenating along dim 0 would
        # multiply the batch dimension by the number of outputs, which breaks
        # the shape check. It would also fail outright on outputs whose
        # trailing shapes differ.
        assert output.shape[0] == batch_size
        assert not torch.isnan(output).any(), 'Output included NaNs'
def test_model_forward_torchscript(model_name, batch_size):
    """Script each model with TorchScript and run a single forward pass.

    NOTE(review): ``test_model_forward_torchscript`` is defined a second time
    later in this module. That later ``def`` rebinds the name, so only the
    later definition is collected by pytest. Consider removing or renaming
    one of the two.

    Args:
        model_name: Name passed to ``create_model``.
        batch_size: Leading dimension of the random input tensor.
    """
    # Ask the model for a JIT-sized input instead of hard-coding 3x128x128. A
    # hard-coded resolution can break models that only accept a fixed input
    # size. (Assumes _get_input_size honours such constraints — confirm.)
    input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
    if max(input_size) > MAX_JIT_SIZE:
        # Scripting large fixed-size models is slow; skip rather than time out.
        pytest.skip("Fixed input size model > limit.")

    with set_scriptable(True):
        model = create_model(model_name, pretrained=False)
    model.eval()

    model = torch.jit.script(model)
    outputs = model(torch.randn((batch_size, *input_size)))

    assert outputs.shape[0] == batch_size
    assert not torch.isnan(outputs).any(), 'Output included NaNs'
def test_model_forward_torchscript(model_name, batch_size):
    """Compile a freshly created model with TorchScript and run it once.

    Models whose input size (as reported by ``_get_input_size`` for
    ``TARGET_JIT_SIZE``) has any dimension above ``MAX_JIT_SIZE`` are skipped.
    The scripted model must return a tensor whose leading dimension equals
    ``batch_size`` and which contains no NaNs.

    Args:
        model_name: Name passed to ``create_model``.
        batch_size: Leading dimension of the random input tensor.
    """
    in_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
    if max(in_size) > MAX_JIT_SIZE:
        pytest.skip("Fixed input size model > limit.")

    with set_scriptable(True):
        net = create_model(model_name, pretrained=False)
    net.eval()
    scripted = torch.jit.script(net)

    sample = torch.randn((batch_size, *in_size))
    result = scripted(sample)

    assert result.shape[0] == batch_size
    has_nan = torch.isnan(result).any()
    assert not has_nan, 'Output included NaNs'