Example #1
0
    def test_model_forward_fx_torchscript(model_name, batch_size):
        """Symbolically trace a model with FX, script it, and run one forward pass.

        Skips when FX feature extraction is unavailable (requires
        torch >= 1.10 and torchvision >= 0.11) or when the model's fixed
        input size exceeds the JIT test limit.
        """
        if not has_fx_feature_extraction:
            pytest.skip(
                "Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required."
            )

        input_size = _get_input_size(model_name=model_name,
                                     target=TARGET_JIT_SIZE)
        if max(input_size) > MAX_JIT_SIZE:
            pytest.skip("Fixed input size model > limit.")

        # Models must be built in scriptable mode for torch.jit.script to work.
        with set_scriptable(True):
            model = create_model(model_name, pretrained=False)
        model.eval()

        model = torch.jit.script(_create_fx_model(model))
        with torch.no_grad():
            # The FX wrapper returns a mapping of outputs (hence .values());
            # concatenate them along dim 0 so one shape/NaN check covers all.
            # FIX: the old `if isinstance(outputs, tuple)` guard was dead code —
            # `tuple(...)` always produces a tuple, so cat unconditionally.
            outputs = torch.cat(
                tuple(model(torch.randn((batch_size, *input_size))).values()))

        assert outputs.shape[0] == batch_size
        assert not torch.isnan(outputs).any(), 'Output included NaNs'
Example #2
0
def test_model_forward_torchscript(model_name, batch_size):
    """Script a model with TorchScript and verify one forward pass."""
    # Reduced resolution on purpose: scripting is already slow, and the
    # normal-resolution forward pass is covered by other tests.
    input_size = (3, 128, 128)

    # Build in scriptable mode so torch.jit.script can compile the model.
    with set_scriptable(True):
        model = create_model(model_name, pretrained=False)
    model.eval()

    scripted = torch.jit.script(model)
    outputs = scripted(torch.randn((batch_size, *input_size)))

    assert outputs.shape[0] == batch_size
    assert not torch.isnan(outputs).any(), 'Output included NaNs'
Example #3
0
def test_model_forward_torchscript(model_name, batch_size):
    """Script a model with TorchScript and verify one forward pass."""
    # Resolve the test resolution for this model; skip models whose fixed
    # input size is too large to JIT-compile within the test budget.
    input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
    if max(input_size) > MAX_JIT_SIZE:
        pytest.skip("Fixed input size model > limit.")

    # Build in scriptable mode so torch.jit.script can compile the model.
    with set_scriptable(True):
        model = create_model(model_name, pretrained=False)
    model.eval()

    scripted = torch.jit.script(model)
    outputs = scripted(torch.randn((batch_size, *input_size)))

    assert outputs.shape[0] == batch_size
    assert not torch.isnan(outputs).any(), 'Output included NaNs'