def test_wrong_model_name(mock_data, mock_model, use_cluster):
    """User requests to run a model that is not there.

    A Torch CNN is stored under the key ``"simple_cnn"`` and an input tensor
    is put under ``"input"``. The run request then names ``"wrong_cnn"``,
    which was never stored, and the test expects ``RedisReplyError``.
    """
    inputs = mock_data.create_data(1)
    torch_model = mock_model.create_torch_cnn()
    client = Client(None, use_cluster)

    client.set_model("simple_cnn", torch_model, "TORCH", "CPU")
    client.put_tensor("input", inputs[0])

    # "wrong_cnn" was never set, so running it must be rejected.
    with pytest.raises(RedisReplyError):
        client.run_model("wrong_cnn", ["input"], ["output"])
def test_wrong_model_name_from_file(mock_data, mock_model, use_cluster):
    """User requests to run a model that is not there that was loaded from file.

    A Torch CNN is written to disk and stored from that file under the key
    ``"simple_cnn_from_file"``. An input tensor is put, and a run is then
    requested for the never-stored key ``"wrong_cnn"``, which must raise
    ``RedisReplyError``.

    The model file lives in a private temporary directory. This keeps the
    test from overwriting a ``torch_cnn.pt`` in the working directory. It
    also means cleanup cannot fail with ``FileNotFoundError`` (and hide the
    real error) when setup raises before the file is written.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "torch_cnn.pt")
        data = mock_data.create_data(1)
        mock_model.create_torch_cnn(filepath=model_path)
        c = Client(None, use_cluster)
        c.set_model_from_file("simple_cnn_from_file", model_path, "TORCH", "CPU")
        c.put_tensor("input", data[0])
        # Only "simple_cnn_from_file" was stored; "wrong_cnn" must fail.
        with pytest.raises(RedisReplyError):
            c.run_model("wrong_cnn", ["input"], ["output"])
def test_torch_inference(mock_model, use_cluster):
    """Store a Torch CNN, run it on a random input and check the output shape.

    The model is stored under ``"torch_cnn"`` (``set_model`` is called
    without a device argument). A random ``(1, 1, 3, 3)`` tensor is fed as
    ``"torch_cnn_input"``, and the result read back from
    ``"torch_cnn_output"`` must have shape ``(1, 1, 1, 1)``.
    """
    # Load the mock model into the database.
    cnn = mock_model.create_torch_cnn()
    client = Client(None, use_cluster)
    client.set_model("torch_cnn", cnn, "TORCH")

    # Random input tensor, converted to a NumPy array for put_tensor.
    sample = torch.rand(1, 1, 3, 3).numpy()
    client.put_tensor("torch_cnn_input", sample)

    # Execute the model and fetch what it produced.
    client.run_model(
        "torch_cnn",
        inputs=["torch_cnn_input"],
        outputs=["torch_cnn_output"],
    )
    result = client.get_tensor("torch_cnn_output")
    assert result.shape == (1, 1, 1, 1)
# Connect a SmartRedis client
db_address = "127.0.0.1:6379"
client = Client(address=db_address, cluster=False)

try:
    net = Net()
    example_forward_input = torch.rand(1, 1, 3, 3)

    # Trace a module (implicitly traces `forward`) and construct a
    # `ScriptModule` with a single `forward` method
    module = torch.jit.trace(net, example_forward_input)

    # Save the traced model to a file
    torch.jit.save(module, "./torch_cnn.pt")

    # Set the model in the Redis database from the file
    client.set_model_from_file("file_cnn", "./torch_cnn.pt", "TORCH", "CPU")

    # Put a tensor in the database as a test input
    data = torch.rand(1, 1, 3, 3).numpy()
    client.put_tensor("torch_cnn_input", data)

    # Run model and retrieve the output
    client.run_model(
        "file_cnn", inputs=["torch_cnn_input"], outputs=["torch_cnn_output"]
    )
    out_data = client.get_tensor("torch_cnn_output")
finally:
    # The file exists only if torch.jit.save ran. Removing it unconditionally
    # would raise FileNotFoundError when an earlier step failed, hiding the
    # original exception.
    if os.path.exists("torch_cnn.pt"):
        os.remove("torch_cnn.pt")
c.set_data_source("producer_1" if args.exchange else "producer_0") data = torch.ones(1, 1, 3, 3).numpy() data_other = -torch.ones(1, 1, 3, 3).numpy() elif keyout == "producer_1": c.set_data_source("producer_0" if args.exchange else "producer_1") data = -torch.ones(1, 1, 3, 3).numpy() data_other = torch.ones(1, 1, 3, 3).numpy() # setup input tensor c.put_tensor("torch_cnn_input", data) input_exists = c.poll_tensor("torch_cnn_input", 100, 100) assert input_exists other_input = c.get_tensor("torch_cnn_input") if args.exchange: assert np.all(other_input == data_other) else: assert np.all(other_input == data) # run model and get output c.run_model("torch_cnn", inputs=["torch_cnn_input"], outputs=["torch_cnn_output"]) output_exists = c.poll_tensor("torch_cnn_output", 100, 100) assert output_exists out_data = c.get_tensor("torch_cnn_output") assert out_data.shape == (1, 1, 1, 1)