Пример #1
0
def test_tensorflow_without_any_model():
    """Loading fails when none of the artifacts is a TensorFlow model.

    The only artifact supplied is ``test.blah``; ``load_model`` is expected to
    raise ``RuntimeError`` whose message reports that no valid TensorFlow
    model was found in the input files.
    """
    model_artifacts = ["test.blah"]
    data_shape = {"input": [1, 3, 224, 224]}
    loader = TensorflowModelLoader(model_artifacts, data_shape)
    with pytest.raises(RuntimeError) as err:
        loader.load_model()
    # Check the raised exception itself: str() of the ExceptionInfo wrapper is
    # a pytest-version-dependent repr, not the exception message.
    assert ('InputConfiguration: No valid TensorFlow model found in input files.'
            in str(err.value))
Пример #2
0
def test_tensorflow_multiple_pb_file():
    """Loading fails when both a ``.pb`` and a ``.pbtxt`` artifact are given.

    ``load_model`` is expected to raise ``RuntimeError`` whose message states
    that exactly one ``.pb`` or ``.pbtxt`` file is allowed.
    """
    model_artifacts = ["test.pb", "test.pbtxt"]
    data_shape = {"input": [1, 3, 224, 224]}
    loader = TensorflowModelLoader(model_artifacts, data_shape)
    with pytest.raises(RuntimeError) as err:
        loader.load_model()
    # Check the raised exception itself: str() of the ExceptionInfo wrapper is
    # a pytest-version-dependent repr, not the exception message.
    expected = ('InputConfiguration: Exactly one .pb or .pbtxt file is allowed '
                'for TensorFlow models.')
    assert expected in str(err.value)
Пример #3
0
def test_tensorflow_op_error(patch_tf_model_helper, patch_tf_parser,
                             patch_relay, patch_op_error):
    """An op error raised by the Relay frontend propagates out of the loader.

    ``patch_op_error`` is used both to build the side effect raised by the
    mocked ``from_tensorflow`` and as the exception type expected from
    ``load_model``; the original message must be preserved.
    """
    patch_relay.frontend.from_tensorflow.side_effect = patch_op_error(
        "Dummy OpError")
    model_artifacts = ["test.pb"]
    data_shape = {"input": [1, 3, 224, 224]}
    loader = TensorflowModelLoader(model_artifacts, data_shape)
    with pytest.raises(patch_op_error) as err:
        loader.load_model()
    # Check the raised exception itself: str() of the ExceptionInfo wrapper is
    # a pytest-version-dependent repr, not the exception message.
    assert 'Dummy OpError' in str(err.value)
Пример #4
0
def test_tensorflow_tf_model_helper_exception(patch_tf_model_helper):
    """A failure inside the TF model helper surfaces as a ``RuntimeError``.

    The mocked helper's ``extract_input_and_output_tensors`` raises a generic
    ``Exception``; ``load_model`` is expected to raise ``RuntimeError`` with
    the "Framework cannot load Tensorflow model" message.
    """
    helper = patch_tf_model_helper.return_value
    helper.extract_input_and_output_tensors.side_effect = Exception(
        "Dummy Exception")
    model_artifacts = ["test.pb"]
    data_shape = {"input": [1, 3, 224, 224]}
    loader = TensorflowModelLoader(model_artifacts, data_shape)
    with pytest.raises(RuntimeError) as err:
        loader.load_model()
    # Check the raised exception itself: str() of the ExceptionInfo wrapper is
    # a pytest-version-dependent repr, not the exception message.
    assert ('InputConfiguration: Framework cannot load Tensorflow model'
            in str(err.value))
Пример #5
0
def test_tensorflow_with_pb_file(patch_tvm, patch_relay, patch_tf_model_helper,
                                 patch_tf_parser, patch_op_error):
    """A single ``.pb`` artifact loads via the TF parser and Relay frontend.

    The mocked ``from_tensorflow`` result iterates as ``["module", "params"]``.
    After ``load_model`` the test verifies that both the parser's ``parse``
    and ``from_tensorflow`` were called.
    """
    from_tensorflow = patch_relay.frontend.from_tensorflow
    from_tensorflow.return_value.__iter__.return_value = ["module", "params"]

    loader = TensorflowModelLoader(["test.pb"], {"input": [1, 3, 224, 224]})
    loader.load_model()

    patch_tf_parser.return_value.parse.assert_called()
    from_tensorflow.assert_called()
Пример #6
0
def test_tensorflow_relay_exception(patch_tf_model_helper, patch_tf_parser,
                                    patch_relay, patch_op_error):
    """A generic Relay frontend failure surfaces as a ``RuntimeError``.

    The mocked ``from_tensorflow`` raises a plain ``Exception``; ``load_model``
    is expected to raise ``RuntimeError`` with the "TVM cannot convert
    Tensorflow model" message.
    """
    patch_relay.frontend.from_tensorflow.side_effect = Exception(
        "Dummy Exception")
    model_artifacts = ["test.pb"]
    data_shape = {"input": [1, 3, 224, 224]}
    loader = TensorflowModelLoader(model_artifacts, data_shape)
    with pytest.raises(RuntimeError) as err:
        loader.load_model()
    # Check the raised exception itself: str() of the ExceptionInfo wrapper is
    # a pytest-version-dependent repr, not the exception message.
    assert ('InputConfiguration: TVM cannot convert Tensorflow model'
            in str(err.value))
Пример #7
0
def test_tensorflow_multiple_saved_model_directory_file():
    """Loading fails when two saved-model directories are supplied.

    Each temporary directory gets a ``variables`` subdirectory (presumably the
    marker the loader uses to recognise a saved model -- confirm against the
    loader). ``load_model`` is expected to raise ``RuntimeError`` stating that
    exactly one saved model is allowed.
    """
    # TemporaryDirectory removes the directories on exit; mkdtemp() would
    # leave them behind after every test run.
    with tempfile.TemporaryDirectory() as tmp1, \
            tempfile.TemporaryDirectory() as tmp2:
        model_dir1 = Path(tmp1)
        model_dir1.joinpath("variables").mkdir(exist_ok=True)
        model_dir2 = Path(tmp2)
        model_dir2.joinpath("variables").mkdir(exist_ok=True)
        model_artifacts = [model_dir1.as_posix(), model_dir2.as_posix()]
        data_shape = {"input": [1, 3, 224, 224]}
        loader = TensorflowModelLoader(model_artifacts, data_shape)
        with pytest.raises(RuntimeError) as err:
            loader.load_model()
    # Check the raised exception itself: str() of the ExceptionInfo wrapper is
    # a pytest-version-dependent repr, not the exception message.
    expected = ('InputConfiguration: Exactly one saved model is allowed for '
                'TensorFlow models.')
    assert expected in str(err.value)
Пример #8
0
def test_tensorflow_with_model_dir(patch_tvm, patch_relay,
                                   patch_tf_model_helper, patch_tf_parser,
                                   patch_op_error):
    """A single saved-model directory loads via the TF parser and Relay.

    The artifact is a temporary directory containing a ``variables``
    subdirectory. The mocked ``from_tensorflow`` result iterates as
    ``["module", "params"]``; after ``load_model`` the test verifies that both
    the parser's ``parse`` and ``from_tensorflow`` were called.
    """
    patch_relay.frontend.from_tensorflow.return_value.__iter__.return_value = [
        "module", "params"
    ]
    # TemporaryDirectory removes the directory on exit; mkdtemp() would leave
    # it behind after every test run.
    with tempfile.TemporaryDirectory() as tmp:
        model_dir = Path(tmp)
        model_dir.joinpath("variables").mkdir(exist_ok=True)
        model_artifacts = [model_dir.as_posix()]
        data_shape = {"input": [1, 3, 224, 224]}
        loader = TensorflowModelLoader(model_artifacts, data_shape)
        # Load while the directory still exists on disk.
        loader.load_model()
    patch_tf_parser.return_value.parse.assert_called()
    patch_relay.frontend.from_tensorflow.assert_called()