def test_custom_input_registry(self):
    """Check that a registered input creator can be looked up and used.

    Also verifies that ``IOContext`` keeps the values it was constructed with
    and exposes ``config["input_config"]`` as ``input_config``.
    """
    cfg = {"input_config": {}}
    io_context = IOContext(self.test_dir, cfg, 0, None)

    class ConstantReader(InputReader):
        """Minimal reader whose ``next()`` always returns ``0``."""

        def __init__(self, ioctx: IOContext):
            self.ioctx = ioctx

        def next(self):
            return 0

    def make_shuffled_reader(ctx: IOContext):
        return ShuffledInput(ConstantReader(ctx))

    register_input("custom_input", make_shuffled_reader)

    # The registry must report the name and hand back a creator for it.
    self.assertTrue(registry_contains_input("custom_input"))
    looked_up = registry_get_input("custom_input")
    self.assertIsNotNone(looked_up)

    # Calling the creator should give a ShuffledInput whose next() yields 0.
    shuffled = looked_up(io_context)
    self.assertIsInstance(shuffled, ShuffledInput)
    self.assertEqual(shuffled.next(), 0)

    # IOContext should keep exactly what it was constructed with.
    expected_fields = (
        ("log_dir", self.test_dir),
        ("config", cfg),
        ("worker_index", 0),
    )
    for field_name, expected in expected_fields:
        self.assertEqual(getattr(io_context, field_name), expected)
    self.assertIsNone(io_context.worker)
    self.assertEqual(io_context.input_config, {})
def test_custom_input_procedure(self):
    """Train briefly using each supported way of naming a custom input.

    The ``"input"`` setting is exercised as a registered name, as a direct
    creator callable, and as a dotted class path, each under both the torch
    and tf frameworks.
    """

    class CustomJsonReader(JsonReader):
        def __init__(self, ioctx: IOContext):
            super().__init__(ioctx.input_config["input_files"], ioctx)

    def input_creator(ioctx: IOContext) -> InputReader:
        return ShuffledInput(CustomJsonReader(ioctx))

    register_input("custom_input", input_creator)
    test_input_procedure = [
        "custom_input",
        input_creator,
        "ray.rllib.examples.custom_input_api.CustomJsonReader",
    ]
    for input_procedure in test_input_procedure:
        for fw in framework_iterator(frameworks=("torch", "tf")):
            # NOTE(review): presumably write_outputs places the offline data
            # under `self.test_dir + fw`, which is what input_files points at.
            self.write_outputs(self.test_dir, fw)
            agent = PGTrainer(
                env="CartPole-v0",
                config={
                    "input": input_procedure,
                    "input_config": {"input_files": self.test_dir + fw},
                    "input_evaluation": [],
                    "framework": fw,
                },
            )
            # Always stop the trainer so its workers/resources are released,
            # even when an assertion fails; otherwise every loop iteration
            # (3 procedures x 2 frameworks) leaks a live trainer.
            try:
                result = agent.train()
                self.assertEqual(result["timesteps_total"], 250)
                self.assertTrue(np.isnan(result["episode_reward_mean"]))
            finally:
                agent.stop()
Returns: instance of ShuffledInput to work with some offline rl algorithms """ return ShuffledInput(CustomJsonReader(ioctx)) if __name__ == "__main__": ray.init() args = parser.parse_args() # make absolute path because relative path looks in result directory args.input_files = os.path.abspath(args.input_files) # we register our custom input creator with this convenient function register_input("custom_input", input_creator) # config modified from rllib/tuned_examples/cql/pendulum-cql.yaml config = { "env": "Pendulum-v1", # we can either use the tune registry, class path, or direct function # to connect our input api. "input": "custom_input", # "input": "ray.rllib.examples.custom_input_api.CustomJsonReader", # "input": input_creator, # this gets passed to the IOContext "input_config": { "input_files": args.input_files, }, "framework": args.framework, "actions_in_input_normalized": True,
def register_minerl_input():
    """Register ``minerl_input_creator`` in the input registry as ``"minerl"``.

    After this call the name ``"minerl"`` can be referenced wherever a
    registered input name is accepted (e.g. an ``"input"`` config setting).
    """
    input_name = "minerl"
    register_input(input_name, minerl_input_creator)