"""SmartRedis ensemble consumer process.

For every incoming entity prefix listed in the ``SSKEYIN`` environment
variable, fetch the tensor named ``product`` from the database and print it.
"""
import argparse
import os

from smartredis import Client

parser = argparse.ArgumentParser(description="SmartRedis ensemble consumer process.")
parser.add_argument("--redis-port")
args = parser.parse_args()

# Connect to the non-clustered database on localhost at the requested port.
c = Client(address=f"127.0.0.1:{args.redis_port}", cluster=False)

# Incoming entity prefixes are stored as a comma-separated list
# in the env variable SSKEYIN
keyin = os.getenv("SSKEYIN")
if not keyin:
    # Fail with a clear message instead of an AttributeError on None.split().
    raise RuntimeError(
        "Environment variable SSKEYIN is not set or empty; "
        "no data sources to read from."
    )

# Sorted so the sources are always visited (and printed) in the same order.
data_sources = sorted(keyin.split(","))

for key in data_sources:
    c.set_data_source(key)

    # Per the SmartRedis API the two numbers are the poll interval in
    # milliseconds and the number of tries; the call returns False when the
    # tensor never shows up.
    input_exists = c.poll_tensor("product", 100, 100)
    if not input_exists:
        raise RuntimeError(
            f'Tensor "product" for data source {key} was not found in the database.'
        )

    db_tensor = c.get_tensor("product")
    print(f"Tensor for {key} is:", db_tensor)
# Only the two known producer prefixes are handled below.
assert keyout in ["producer_0", "producer_1"]

# Select the data source to read from and build two tensors: `data`, which
# this process writes, and `data_other`, the same tensor with the sign flipped.
# With --exchange the data source is the *other* producer's prefix.
if keyout == "producer_0":
    if args.exchange:
        c.set_data_source("producer_1")
    else:
        c.set_data_source("producer_0")
    data = torch.ones(1, 1, 3, 3).numpy()
    data_other = -data
elif keyout == "producer_1":
    if args.exchange:
        c.set_data_source("producer_0")
    else:
        c.set_data_source("producer_1")
    data_other = torch.ones(1, 1, 3, 3).numpy()
    data = -data_other

# Write the input tensor, then wait until it is visible through the
# selected data source.
c.put_tensor("torch_cnn_input", data)
input_exists = c.poll_tensor("torch_cnn_input", 100, 100)
assert input_exists

# When exchanging, the tensor read back must be the sign-flipped one;
# otherwise it must be the tensor just written.
other_input = c.get_tensor("torch_cnn_input")
expected_input = data_other if args.exchange else data
assert np.all(other_input == expected_input)

# Run the model and wait for its output tensor to appear.
c.run_model(
    "torch_cnn",
    inputs=["torch_cnn_input"],
    outputs=["torch_cnn_output"],
)
output_exists = c.poll_tensor("torch_cnn_output", 100, 100)
assert output_exists