Example #1
0
    def test_custom_input_registry(self):
        """A creator registered by name can be found, fetched and invoked.

        Registers a creator that wraps a trivial reader in ``ShuffledInput``,
        checks the registry reports and returns it, calls it with an
        ``IOContext`` and then checks the reader it produced as well as the
        context fields.
        """
        cfg = {"input_config": {}}
        ctx = IOContext(self.test_dir, cfg, 0, None)

        class ZeroReader(InputReader):
            """Reader stub whose ``next()`` always returns 0."""

            def __init__(self, ioctx: IOContext):
                self.ioctx = ioctx

            def next(self):
                return 0

        def make_reader(ioctx: IOContext):
            return ShuffledInput(ZeroReader(ioctx))

        name = "custom_input"
        register_input(name, make_reader)
        self.assertTrue(registry_contains_input(name))

        factory = registry_get_input(name)
        self.assertIsNotNone(factory)

        shuffled = factory(ctx)
        self.assertIsInstance(shuffled, ShuffledInput)
        self.assertEqual(shuffled.next(), 0)

        # The context's fields reflect its constructor arguments;
        # input_config presumably comes from cfg["input_config"].
        self.assertEqual(ctx.log_dir, self.test_dir)
        self.assertEqual(ctx.config, cfg)
        self.assertEqual(ctx.worker_index, 0)
        self.assertIsNone(ctx.worker)
        self.assertEqual(ctx.input_config, {})
Example #2
0
    def test_custom_input_procedure(self):
        """Train PG from offline JSON data using each supported ``input`` form.

        The reader is wired in three ways: by registered name, by passing the
        creator function directly, and by a dotted class path. Each form is
        exercised under both torch and tf. Every trainer is stopped after its
        check so that workers do not accumulate across iterations.
        """

        class CustomJsonReader(JsonReader):
            # Takes its file list from ``input_config`` on the IOContext
            # instead of as an explicit constructor argument.
            def __init__(self, ioctx: IOContext):
                super().__init__(ioctx.input_config["input_files"], ioctx)

        def input_creator(ioctx: IOContext) -> InputReader:
            return ShuffledInput(CustomJsonReader(ioctx))

        register_input("custom_input", input_creator)
        test_input_procedure = [
            # Name looked up in the input registry.
            "custom_input",
            # Creator callable passed directly.
            input_creator,
            # Dotted path: resolves to the example module's reader class,
            # not the local one defined above.
            "ray.rllib.examples.custom_input_api.CustomJsonReader",
        ]
        for input_procedure in test_input_procedure:
            for fw in framework_iterator(frameworks=("torch", "tf")):
                self.write_outputs(self.test_dir, fw)
                agent = PGTrainer(env="CartPole-v0",
                                  config={
                                      "input": input_procedure,
                                      # Presumably the location write_outputs
                                      # used for this framework's data.
                                      "input_config": {
                                          "input_files": self.test_dir + fw
                                      },
                                      "input_evaluation": [],
                                      "framework": fw,
                                  })
                try:
                    result = agent.train()
                    self.assertEqual(result["timesteps_total"], 250)
                    # Expected NaN: presumably no episodes are rolled out
                    # in the env when training purely from offline input.
                    self.assertTrue(np.isnan(result["episode_reward_mean"]))
                finally:
                    # Shut the trainer down even if an assertion fails, so
                    # each iteration does not leak its workers/resources.
                    agent.stop()
Example #3
0
    Returns:
        instance of ShuffledInput to work with some offline rl algorithms
    """
    return ShuffledInput(CustomJsonReader(ioctx))


if __name__ == "__main__":
    ray.init()
    args = parser.parse_args()

    # make absolute path because relative path looks in result directory
    args.input_files = os.path.abspath(args.input_files)

    # we register our custom input creator with this convenient function
    register_input("custom_input", input_creator)

    # config modified from rllib/tuned_examples/cql/pendulum-cql.yaml
    config = {
        "env": "Pendulum-v1",
        # we can either use the tune registry, class path, or direct function
        # to connect our input api.
        "input": "custom_input",
        # "input": "ray.rllib.examples.custom_input_api.CustomJsonReader",
        # "input": input_creator,
        # this gets passed to the IOContext
        "input_config": {
            "input_files": args.input_files,
        },
        "framework": args.framework,
        "actions_in_input_normalized": True,
Example #4
0
def register_minerl_input():
    """Register ``minerl_input_creator`` under the input name "minerl"."""
    input_name = "minerl"
    register_input(input_name, minerl_input_creator)