def test_wrong_model_name(mock_data, mock_model, use_cluster):
    """User requests to run a model that is not there.

    A CNN is stored under the key ``"simple_cnn"`` and an input tensor is
    staged, but the run request targets ``"wrong_cnn"``, a key this test
    never set. The client is expected to raise ``RedisReplyError``.
    """
    samples = mock_data.create_data(1)
    cnn = mock_model.create_torch_cnn()

    client = Client(None, use_cluster)
    client.set_model("simple_cnn", cnn, "TORCH", "CPU")
    client.put_tensor("input", samples[0])

    # Only "simple_cnn" was stored above, so asking for "wrong_cnn" must fail.
    with pytest.raises(RedisReplyError):
        client.run_model("wrong_cnn", ["input"], ["output"])
def test_torch_inference(mock_model, use_cluster):
    """Store a Torch CNN, run it on a random input, and check the output shape.

    The model is stored under ``"torch_cnn"``, fed a random ``(1, 1, 3, 3)``
    tensor, and the tensor written to ``"torch_cnn_output"`` must have shape
    ``(1, 1, 1, 1)``.
    """
    # Get the test model from the fixture and store it in the database.
    # The device is passed explicitly rather than relying on the client default.
    model = mock_model.create_torch_cnn()
    c = Client(None, use_cluster)
    c.set_model("torch_cnn", model, "TORCH", "CPU")

    # Set up a random float input tensor of shape (1, 1, 3, 3).
    data = torch.rand(1, 1, 3, 3).numpy()
    c.put_tensor("torch_cnn_input", data)

    # Run the model in the database and fetch its output tensor.
    c.run_model(
        "torch_cnn",
        inputs=["torch_cnn_input"],
        outputs=["torch_cnn_output"],
    )
    out_data = c.get_tensor("torch_cnn_output")
    assert out_data.shape == (1, 1, 1, 1)
buffer = io.BytesIO() torch.jit.save(module, buffer) str_model = buffer.getvalue() return str_model if __name__ == "__main__": parser = argparse.ArgumentParser( description="SmartRedis ensemble producer process.") parser.add_argument("--exchange", action="store_true") args = parser.parse_args() # get model and set into database model = create_torch_cnn() c = Client(False) c.set_model("torch_cnn", model, "TORCH") keyout = os.getenv("SSKEYOUT") keyin = os.getenv("SSKEYIN") assert keyout in ["producer_0", "producer_1"] if keyout == "producer_0": c.set_data_source("producer_1" if args.exchange else "producer_0") data = torch.ones(1, 1, 3, 3).numpy() data_other = -torch.ones(1, 1, 3, 3).numpy() elif keyout == "producer_1": c.set_data_source("producer_0" if args.exchange else "producer_1") data = -torch.ones(1, 1, 3, 3).numpy() data_other = torch.ones(1, 1, 3, 3).numpy()
def test_set_model(mock_model, use_cluster):
    """A model stored with ``set_model`` is returned unchanged by ``get_model``.

    The CNN from the ``mock_model`` fixture is stored under ``"simple_cnn"``
    on the CPU device, read back, and compared for equality with the original.
    """
    original = mock_model.create_torch_cnn()

    client = Client(None, use_cluster)
    client.set_model("simple_cnn", original, "TORCH", "CPU")

    fetched = client.get_model("simple_cnn")
    assert original == fetched