Example #1
0
def test_wrong_model_name(mock_data, mock_model, use_cluster):
    """Running a model under a key this test never registered must fail.

    The test stores a Torch CNN under ``"simple_cnn"`` and an input tensor
    under ``"input"``, then asks the client to run ``"wrong_cnn"``.  The
    run call is expected to raise ``RedisReplyError``.
    """
    stored_key = "simple_cnn"
    missing_key = "wrong_cnn"

    samples = mock_data.create_data(1)
    cnn = mock_model.create_torch_cnn()

    client = Client(None, use_cluster)
    client.set_model(stored_key, cnn, "TORCH", "CPU")
    # Only the first generated sample is needed as model input.
    client.put_tensor("input", samples[0])

    with pytest.raises(RedisReplyError):
        client.run_model(missing_key, inputs=["input"], outputs=["output"])
Example #2
0
def test_torch_inference(mock_model, use_cluster):
    """Store a Torch CNN, run it on a random tensor, and check the output shape.

    The input is a random ``(1, 1, 3, 3)`` tensor converted to a NumPy
    array; the test passes when the tensor read back from the output key
    has shape ``(1, 1, 1, 1)``.
    """
    model_key = "torch_cnn"
    input_key = "torch_cnn_input"
    output_key = "torch_cnn_output"

    # Register the model in the database.
    cnn = mock_model.create_torch_cnn()
    client = Client(None, use_cluster)
    client.set_model(model_key, cnn, "TORCH")

    # Upload the input tensor.
    sample = torch.rand(1, 1, 3, 3).numpy()
    client.put_tensor(input_key, sample)

    # Execute the model, then fetch and verify the result.
    client.run_model(model_key, [input_key], [output_key])
    result = client.get_tensor(output_key)
    assert result.shape == (1, 1, 1, 1)
Example #3
0
    buffer = io.BytesIO()
    torch.jit.save(module, buffer)
    str_model = buffer.getvalue()
    return str_model


# Script entry point: parses the --exchange flag, stores a Torch CNN in the
# database, then picks a data source and a pair of test tensors based on the
# SSKEYOUT environment variable.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="SmartRedis ensemble producer process.")
    # --exchange is a boolean switch (False unless passed on the command line).
    parser.add_argument("--exchange", action="store_true")
    args = parser.parse_args()

    # get model and set into database
    # (create_torch_cnn is defined outside this block.)
    model = create_torch_cnn()
    # NOTE(review): the meaning of the single positional False is not visible
    # here -- confirm against the Client constructor signature in use.
    c = Client(False)
    c.set_model("torch_cnn", model, "TORCH")

    # Both values are None when the corresponding variable is unset.
    keyout = os.getenv("SSKEYOUT")
    # keyin is not used in the visible part of this script.
    keyin = os.getenv("SSKEYIN")

    # Only the two known producer names are accepted (an unset SSKEYOUT fails
    # here too).  NOTE(review): this check is skipped when Python runs with -O.
    assert keyout in ["producer_0", "producer_1"]

    # With --exchange each producer names the *other* producer as its data
    # source; without it, each names itself.  The two producers get tensors of
    # opposite sign: `data` is all +1 for producer_0 and all -1 for producer_1,
    # and `data_other` holds the opposite value in each case.
    # data / data_other are not consumed in the visible lines -- presumably
    # used further down in the script; verify.
    if keyout == "producer_0":
        c.set_data_source("producer_1" if args.exchange else "producer_0")
        data = torch.ones(1, 1, 3, 3).numpy()
        # The unary minus applies to the NumPy array returned by .numpy().
        data_other = -torch.ones(1, 1, 3, 3).numpy()
    elif keyout == "producer_1":
        c.set_data_source("producer_0" if args.exchange else "producer_1")
        data = -torch.ones(1, 1, 3, 3).numpy()
        data_other = torch.ones(1, 1, 3, 3).numpy()
Example #4
0
def test_set_model(mock_model, use_cluster):
    """Round-trip a model through the database and expect it back unchanged.

    A Torch CNN is stored under ``"simple_cnn"`` and immediately fetched
    again; the fetched value must compare equal to what was stored.
    """
    model_key = "simple_cnn"
    original = mock_model.create_torch_cnn()

    client = Client(None, use_cluster)
    client.set_model(model_key, original, "TORCH", "CPU")

    assert original == client.get_model(model_key)