Exemple #1
0
def test_wrong_model_name(mock_data, mock_model, use_cluster):
    """Running a model key that was never stored must raise.

    A Torch CNN is stored under "simple_cnn" and one input tensor is
    placed in the database; a run request for "wrong_cnn" is then
    expected to fail with RedisReplyError.
    """

    samples = mock_data.create_data(1)

    # Store the model under its real key on the CPU device.
    torch_cnn = mock_model.create_torch_cnn()
    client = Client(None, use_cluster)
    client.set_model("simple_cnn", torch_cnn, "TORCH", "CPU")

    # The input tensor exists, so only the model key is wrong.
    client.put_tensor("input", samples[0])
    with pytest.raises(RedisReplyError):
        client.run_model("wrong_cnn", ["input"], ["output"])
Exemple #2
0
def test_wrong_model_name_from_file(mock_data, mock_model, use_cluster):
    """Running a model key that was never stored must raise (file variant).

    A Torch CNN is written to disk, loaded into the database from that
    file under "simple_cnn_from_file", and one input tensor is stored.
    A run request for "wrong_cnn" is then expected to fail with
    RedisReplyError. The model file is removed afterwards.
    """

    # Single source of truth for the path used to save, load and clean up.
    model_path = "./torch_cnn.pt"

    try:
        data = mock_data.create_data(1)
        mock_model.create_torch_cnn(filepath=model_path)
        c = Client(None, use_cluster)
        c.set_model_from_file("simple_cnn_from_file", model_path,
                              "TORCH", "CPU")
        c.put_tensor("input", data[0])
        with pytest.raises(RedisReplyError):
            c.run_model("wrong_cnn", ["input"], ["output"])
    finally:
        # Only delete the file if it was actually written. An
        # unconditional remove would raise FileNotFoundError and hide
        # the real failure when an earlier step broke before saving.
        if os.path.exists(model_path):
            os.remove(model_path)
Exemple #3
0
def test_torch_inference(mock_model, use_cluster):
    """Store a Torch CNN, run it on a random tensor, and check the output shape."""
    model_key = "torch_cnn"
    input_key = "torch_cnn_input"
    output_key = "torch_cnn_output"

    # Put the model into the database.
    torch_cnn = mock_model.create_torch_cnn()
    client = Client(None, use_cluster)
    client.set_model(model_key, torch_cnn, "TORCH")

    # Put a random (1, 1, 3, 3) input tensor into the database.
    sample = torch.rand(1, 1, 3, 3).numpy()
    client.put_tensor(input_key, sample)

    # Run the model, then fetch the tensor it wrote.
    client.run_model(model_key, inputs=[input_key], outputs=[output_key])
    result = client.get_tensor(output_key)
    assert result.shape == (1, 1, 1, 1)

# Connect a SmartRedis client to a non-clustered database
db_address = "127.0.0.1:6379"
client = Client(address=db_address, cluster=False)

# Single path used to save, load and clean up the traced model
model_file = "./torch_cnn.pt"

try:
    net = Net()
    example_forward_input = torch.rand(1, 1, 3, 3)
    # Trace a module (implicitly traces `forward`) and construct a
    # `ScriptModule` with a single `forward` method
    module = torch.jit.trace(net, example_forward_input)

    # Save the traced model to a file
    torch.jit.save(module, model_file)

    # Set the model in the Redis database from the file
    client.set_model_from_file("file_cnn", model_file, "TORCH", "CPU")

    # Put a tensor in the database as a test input
    data = torch.rand(1, 1, 3, 3).numpy()
    client.put_tensor("torch_cnn_input", data)

    # Run model and retrieve the output
    client.run_model("file_cnn",
                     inputs=["torch_cnn_input"],
                     outputs=["torch_cnn_output"])
    out_data = client.get_tensor("torch_cnn_output")
finally:
    # Remove the model file only if it was actually written. An
    # unconditional remove would raise FileNotFoundError and mask the
    # original error when building, tracing or saving the model failed.
    if os.path.exists(model_file):
        os.remove(model_file)
Exemple #5
0
        c.set_data_source("producer_1" if args.exchange else "producer_0")
        data = torch.ones(1, 1, 3, 3).numpy()
        data_other = -torch.ones(1, 1, 3, 3).numpy()
    elif keyout == "producer_1":
        c.set_data_source("producer_0" if args.exchange else "producer_1")
        data = -torch.ones(1, 1, 3, 3).numpy()
        data_other = torch.ones(1, 1, 3, 3).numpy()

    # setup input tensor
    c.put_tensor("torch_cnn_input", data)

    input_exists = c.poll_tensor("torch_cnn_input", 100, 100)
    assert input_exists

    other_input = c.get_tensor("torch_cnn_input")

    if args.exchange:
        assert np.all(other_input == data_other)
    else:
        assert np.all(other_input == data)

    # run model and get output
    c.run_model("torch_cnn",
                inputs=["torch_cnn_input"],
                outputs=["torch_cnn_output"])
    output_exists = c.poll_tensor("torch_cnn_output", 100, 100)
    assert output_exists

    out_data = c.get_tensor("torch_cnn_output")
    assert out_data.shape == (1, 1, 1, 1)