Example No. 1
    def test_torch_repeated(self):
        ModelCatalog.register_custom_model("r1", TorchRepeatedSpyModel)
        register_env("repeat", lambda _: RepeatedSpaceEnv())
        a2c = A2CTrainer(
            env="repeat",
            config={
                "num_workers": 0,
                "rollout_fragment_length": 5,
                "train_batch_size": 5,
                "model": {
                    "custom_model": "r1",
                },
                "framework": "torch",
            },
        )

        # Skip first passes as they came from the TorchPolicy loss
        # initialization.
        TorchRepeatedSpyModel.capture_index = 0
        a2c.train()

        # Check that the model sees the correct reconstructed observations
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "torch_rspy_in_{}".format(i)
                )
            )

            # Only look at the last entry (-1) in `seen` as we reset (re-use)
            # the ray-kv indices before training.
            self.assertEqual(to_list(seen[:][-1]), to_list(REPEATED_SAMPLES[i]))
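The spy-model pattern used here (and in the nested-observation tests below) works by having the custom model stash whatever it sees into Ray's internal key/value store, which the test then reads back via _internal_kv_get. A minimal sketch of that idea, with illustrative names rather than the actual TorchRepeatedSpyModel:

import pickle

import ray.experimental.internal_kv as internal_kv
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork


class MySpyModel(FullyConnectedNetwork):
    capture_index = 0

    def forward(self, input_dict, state, seq_lens):
        # Store the flattened observation batch under a per-call key so the
        # test process can unpickle and inspect it later.
        internal_kv._internal_kv_put(
            "my_spy_in_{}".format(MySpyModel.capture_index),
            pickle.dumps(input_dict["obs_flat"].detach().cpu().numpy()),
            overwrite=True)
        MySpyModel.capture_index += 1
        return super().forward(input_dict, state, seq_lens)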
Example No. 2
 def test_a2c_exec_impl(ray_start_regular):
     trainer = A2CTrainer(env="CartPole-v0",
                          config={
                              "min_iter_time_s": 0,
                          })
     assert isinstance(trainer.train(), dict)
     check_compute_action(trainer)
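Outside a test harness, the same trainer can be run standalone; a minimal sketch, assuming Ray and Gym's CartPole-v0 are installed:

import ray
from ray.rllib.agents.a3c import A2CTrainer

ray.init(ignore_reinit_error=True)
trainer = A2CTrainer(env="CartPole-v0", config={"num_workers": 1})
for _ in range(3):
    result = trainer.train()
    # train() returns a result dict; episode_reward_mean is one of its keys.
    print(result["episode_reward_mean"])
trainer.stop()
ray.shutdown()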
Example No. 3
    def test_py_torch_model(self):
        ModelCatalog.register_custom_model("composite", TorchSpyModel)
        register_env("nested", lambda _: NestedDictEnv())
        a2c = A2CTrainer(env="nested",
                         config={
                             "num_workers": 0,
                             "rollout_fragment_length": 5,
                             "train_batch_size": 5,
                             "model": {
                                 "custom_model": "composite",
                             },
                             "framework": "torch",
                         })

        a2c.train()

        # Check that the model sees the correct reconstructed observations
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "torch_spy_in_{}".format(i)))
            pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
            cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
            task_i = one_hot(
                DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5)
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            self.assertEqual(seen[2][0].tolist(), task_i)
Example No. 4
 def test_global_vars_update(self):
     # Allow for Unittest run.
     ray.init(num_cpus=5, ignore_reinit_error=True)
     for fw in framework_iterator(frameworks=("tf2", "tf")):
         agent = A2CTrainer(
             env="CartPole-v0",
             config={
                 "num_workers": 1,
                 # lr = 0.1 - [(0.1 - 0.000001) / 100000] * ts
                 "lr_schedule": [[0, 0.1], [100000, 0.000001]],
                 "framework": fw,
             })
         policy = agent.get_policy()
         for i in range(3):
             result = agent.train()
             print("{}={}".format(STEPS_TRAINED_COUNTER,
                                  result["info"][STEPS_TRAINED_COUNTER]))
             print("{}={}".format(STEPS_SAMPLED_COUNTER,
                                  result["info"][STEPS_SAMPLED_COUNTER]))
             global_timesteps = policy.global_timestep
             print("global_timesteps={}".format(global_timesteps))
             expected_lr = \
                 0.1 - ((0.1 - 0.000001) / 100000) * global_timesteps
             lr = policy.cur_lr
             if fw == "tf":
                 lr = policy.get_session().run(lr)
             check(lr, expected_lr, rtol=0.05)
         agent.stop()
Example No. 5
 def test_global_vars_update(self):
     for fw in framework_iterator(frameworks=("tf2", "tf")):
         agent = A2CTrainer(
             env="CartPole-v0",
             config={
                 "num_workers": 1,
                 # lr = 0.1 - [(0.1 - 0.000001) / 100000] * ts
                 "lr_schedule": [[0, 0.1], [100000, 0.000001]],
                 "framework": fw,
             },
         )
         policy = agent.get_policy()
         for i in range(3):
             result = agent.train()
             print("{}={}".format(NUM_AGENT_STEPS_TRAINED,
                                  result["info"][NUM_AGENT_STEPS_TRAINED]))
             print("{}={}".format(NUM_AGENT_STEPS_SAMPLED,
                                  result["info"][NUM_AGENT_STEPS_SAMPLED]))
             global_timesteps = (policy.global_timestep if fw == "tf" else
                                 policy.global_timestep.numpy())
             print("global_timesteps={}".format(global_timesteps))
             expected_lr = 0.1 - (
                 (0.1 - 0.000001) / 100000) * global_timesteps
             lr = policy.cur_lr
             if fw == "tf":
                 lr = policy.get_session().run(lr)
             check(lr, expected_lr, rtol=0.05)
         agent.stop()
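The expected_lr computation in the two tests above is plain linear interpolation between the two schedule points. A small helper reproducing that math (a sketch, not RLlib's internal schedule implementation):

def linear_lr(ts, schedule=((0, 0.1), (100000, 0.000001))):
    # Interpolate between the two (timestep, lr) points and clamp to the
    # final value once the schedule horizon has passed.
    (t0, lr0), (t1, lr1) = schedule
    if ts >= t1:
        return lr1
    return lr0 + (lr1 - lr0) * (ts - t0) / float(t1 - t0)


# E.g. after 5000 sampled timesteps the schedule yields roughly 0.095.
assert abs(linear_lr(5000) - 0.095) < 1e-3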
Example No. 6
 def test_a2c_exec_impl(ray_start_regular):
     trainer = A2CTrainer(env="CartPole-v0",
                          config={
                              "min_iter_time_s": 0,
                              "use_exec_api": True
                          })
     assert isinstance(trainer.train(), dict)
Example No. 7
    def test_py_torch_model(self):
        ModelCatalog.register_custom_model("composite", TorchSpyModel)
        register_env("nested", lambda _: NestedDictEnv())
        a2c = A2CTrainer(env="nested",
                         config={
                             "num_workers": 0,
                             "rollout_fragment_length": 5,
                             "train_batch_size": 5,
                             "model": {
                                 "custom_model": "composite",
                             },
                             "framework": "torch",
                         })

        # Skip first passes as they came from the TorchPolicy loss
        # initialization.
        TorchSpyModel.capture_index = 0
        a2c.train()

        # Check that the model sees the correct reconstructed observations
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "torch_spy_in_{}".format(i)))

            pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
            cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
            task_i = one_hot(
                DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5)
            # Only look at the last entry (-1) in `seen` as we reset (re-use)
            # the ray-kv indices before training.
            self.assertEqual(seen[0][-1].tolist(), pos_i)
            self.assertEqual(seen[1][-1].tolist(), cam_i)
            check(seen[2][-1], task_i)
Example No. 8
 def test_global_vars_update(self):
     # Allow for Unittest run.
     ray.init(num_cpus=5, ignore_reinit_error=True)
     for fw in framework_iterator(frameworks=()):
         agent = A2CTrainer(
             env="CartPole-v0",
             config={
                 "num_workers": 1,
                 "lr_schedule": [[0, 0.1], [100000, 0.000001]],
                 "framework": fw,
             })
         result = agent.train()
         for i in range(10):
             result = agent.train()
             print("num_steps_sampled={}".format(
                 result["info"]["num_steps_sampled"]))
             print("num_steps_trained={}".format(
                 result["info"]["num_steps_trained"]))
             print("num_steps_sampled={}".format(
                 result["info"]["num_steps_sampled"]))
             print("num_steps_trained={}".format(
                 result["info"]["num_steps_trained"]))
             if i == 0:
                 self.assertGreater(
                     result["info"]["learner"]["default_policy"]["cur_lr"],
                     0.01)
             if result["info"]["learner"]["default_policy"]["cur_lr"] < \
                     0.07:
                 break
         self.assertLess(
             result["info"]["learner"]["default_policy"]["cur_lr"], 0.07)
         agent.stop()
Example No. 9
 def test_a2c_exec_impl_microbatch(ray_start_regular):
     trainer = A2CTrainer(env="CartPole-v0",
                          config={
                              "min_iter_time_s": 0,
                              "microbatch_size": 10,
                              "use_exec_api": True,
                          })
     assert isinstance(trainer.train(), dict)
     check_compute_action(trainer)
Example No. 10
 def test_a2c_pipeline_microbatch(ray_start_regular):
     trainer = A2CTrainer(
         env="CartPole-v0",
         config={
             "min_iter_time_s": 0,
             "microbatch_size": 10,
             "use_pipeline_impl": True,
         })
     assert isinstance(trainer.train(), dict)
Example No. 11
 def testGlobalVarsUpdate(self):
     agent = A2CTrainer(env="CartPole-v0",
                        config={
                            "lr_schedule": [[0, 0.1], [400, 0.000001]],
                        })
     result = agent.train()
     self.assertGreater(result["info"]["learner"]["cur_lr"], 0.01)
     result2 = agent.train()
     self.assertLess(result2["info"]["learner"]["cur_lr"], 0.0001)
Example No. 12
    def test_exec_plan_save_restore(ray_start_regular):
        trainer = A2CTrainer(env="CartPole-v0",
                             config={
                                 "min_iter_time_s": 0,
                             })
        res1 = trainer.train()
        checkpoint = trainer.save()
        for _ in range(2):
            res2 = trainer.train()
        assert res2["timesteps_total"] > res1["timesteps_total"], (res1, res2)
        trainer.restore(checkpoint)

        # Restoring rolls the timesteps counter back to the checkpoint, so one
        # more train() should still end up below res2's total.
        res3 = trainer.train()
        assert res3["timesteps_total"] < res2["timesteps_total"], (res2, res3)
Example No. 13
    def test_pipeline_save_restore(ray_start_regular):
        trainer = A2CTrainer(env="CartPole-v0",
                             config={
                                 "min_iter_time_s": 0,
                                 "use_pipeline_impl": True
                             })
        res1 = trainer.train()
        checkpoint = trainer.save()
        res2 = trainer.train()
        assert res2["timesteps_total"] > res1["timesteps_total"], (res1, res2)
        trainer.restore(checkpoint)

        # Should restore the timesteps counter to the same as res2.
        res3 = trainer.train()
        assert res3["timesteps_total"] == res2["timesteps_total"], (res2, res3)
Example No. 14
 def test_exec_plan_stats(ray_start_regular):
     trainer = A2CTrainer(env="CartPole-v0",
                          config={
                              "min_iter_time_s": 0,
                          })
     result = trainer.train()
     assert isinstance(result, dict)
     assert "info" in result
     assert "learner" in result["info"]
     assert "num_steps_sampled" in result["info"]
     assert "num_steps_trained" in result["info"]
     assert "timers" in result
     assert "learn_time_ms" in result["timers"]
     assert "learn_throughput" in result["timers"]
     assert "sample_time_ms" in result["timers"]
     assert "sample_throughput" in result["timers"]
     assert "update_time_ms" in result["timers"]
Example No. 15
 def test_global_vars_update(self):
     ray.init(num_cpus=5, ignore_reinit_error=True)
     agent = A2CTrainer(env="CartPole-v0",
                        config={
                            "lr_schedule": [[0, 0.1], [400, 0.000001]],
                        })
     result = agent.train()
     self.assertGreater(result["info"]["learner"]["cur_lr"], 0.01)
     result2 = agent.train()
     print("num_steps_sampled={}".format(
         result["info"]["num_steps_sampled"]))
     print("num_steps_trained={}".format(
         result["info"]["num_steps_trained"]))
     self.assertLess(result2["info"]["learner"]["cur_lr"], 0.09)
     print("num_steps_sampled={}".format(
         result["info"]["num_steps_sampled"]))
     print("num_steps_trained={}".format(
         result["info"]["num_steps_trained"]))
Example No. 16
 def test_exec_plan_stats(ray_start_regular):
     for fw in framework_iterator(frameworks=("torch", "tf")):
         trainer = A2CTrainer(env="CartPole-v0",
                              config={
                                  "min_time_s_per_reporting": 0,
                                  "framework": fw,
                              })
         result = trainer.train()
         assert isinstance(result, dict)
         assert "info" in result
         assert LEARNER_INFO in result["info"]
         assert STEPS_SAMPLED_COUNTER in result["info"]
         assert STEPS_TRAINED_COUNTER in result["info"]
         assert "timers" in result
         assert "learn_time_ms" in result["timers"]
         assert "learn_throughput" in result["timers"]
         assert "sample_time_ms" in result["timers"]
         assert "sample_throughput" in result["timers"]
         assert "update_time_ms" in result["timers"]
Example No. 17
    def test_exec_plan_save_restore(ray_start_regular):
        for fw in framework_iterator(frameworks=("torch", "tf")):
            trainer = A2CTrainer(env="CartPole-v0",
                                 config={
                                     "min_time_s_per_reporting": 0,
                                     "framework": fw,
                                 })
            res1 = trainer.train()
            checkpoint = trainer.save()
            for _ in range(2):
                res2 = trainer.train()
            assert res2["timesteps_total"] > res1["timesteps_total"], \
                (res1, res2)
            trainer.restore(checkpoint)

            # Restoring rolls the timesteps counter back to the checkpoint, so
            # one more train() should still end up below res2's total.
            res3 = trainer.train()
            assert res3["timesteps_total"] < res2["timesteps_total"], \
                (res2, res3)
Example No. 18
def test_dependency_torch():
    # Do not import torch for testing purposes.
    os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"

    from ray.rllib.agents.a3c import A2CTrainer
    assert "torch" not in sys.modules, \
        "Torch initially present, when it shouldn't."

    # note: no ray.init(), to test it works without Ray
    trainer = A2CTrainer(env="CartPole-v0",
                         config={
                             "framework": "tf",
                             "num_workers": 0
                         })
    trainer.train()

    assert "torch" not in sys.modules, "Torch should not be imported"

    # Clean up.
    del os.environ["RLLIB_TEST_NO_TORCH_IMPORT"]

    print("ok")
Example No. 19
    def test_torch_repeated(self):
        ModelCatalog.register_custom_model("r1", TorchRepeatedSpyModel)
        register_env("repeat", lambda _: RepeatedSpaceEnv())
        a2c = A2CTrainer(env="repeat",
                         config={
                             "num_workers": 0,
                             "rollout_fragment_length": 5,
                             "train_batch_size": 5,
                             "model": {
                                 "custom_model": "r1",
                             },
                             "framework": "torch",
                         })

        a2c.train()

        # Check that the model sees the correct reconstructed observations
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "torch_rspy_in_{}".format(i)))
            self.assertEqual(to_list(seen), [to_list(REPEATED_SAMPLES[i])])
Example No. 20
    os.environ["RLLIB_TEST_NO_TF_IMPORT"] = "1"

    # Test registering (includes importing) all Trainers.
    from ray.rllib import _register_all

    # This should surface any dependency on tf, e.g. inside function
    # signatures/typehints.
    _register_all()

    from ray.rllib.agents.a3c import A2CTrainer

    assert ("tensorflow" not in sys.modules
            ), "`tensorflow` initially present, when it shouldn't!"

    # Note: No ray.init(), to test it works without Ray
    trainer = A2CTrainer(env="CartPole-v0",
                         config={
                             "framework": "torch",
                             "num_workers": 0
                         })
    trainer.train()

    assert (
        "tensorflow" not in sys.modules
    ), "`tensorflow` should not be imported after creating and training A3CTrainer!"

    # Clean up.
    del os.environ["RLLIB_TEST_NO_TF_IMPORT"]

    print("ok")
Example No. 21
#!/usr/bin/env python

import os
import sys

if __name__ == "__main__":
    # Do not import torch for testing purposes.
    os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"

    from ray.rllib.agents.a3c import A2CTrainer
    assert "torch" not in sys.modules, \
        "Torch initially present, when it shouldn't."

    # note: no ray.init(), to test it works without Ray
    trainer = A2CTrainer(env="CartPole-v0",
                         config={
                             "use_pytorch": False,
                             "num_workers": 0
                         })
    trainer.train()

    assert "torch" not in sys.modules, "Torch should not be imported"
Example No. 22
    args = getArgs()

    ray.init(num_gpus=1)
    register_env("custom-explorer", env_creator)

    config = DEFAULT_CONFIG.copy()
    config['num_workers'] = args.workers
    config['num_gpus'] = 1
    config['framework'] = "torch"
    config['gamma'] = args.gamma

    config['model']['dim'] = 21
    config['model']['conv_filters'] = [[8, [3, 3], 2], [16, [2, 2], 2],
                                       [512, [6, 6], 1]]

    trainer = A2CTrainer(config=config, env="mars_explorer:explorer-v01")

    if PATH != "":
        print(f"\nLoading trainer from dir {PATH}")
        trainer.restore(PATH)
    else:
        print("Starting training without a priori knowledge")

    N_start = 0
    N_finish = args.steps
    results = []
    episode_data = []
    episode_json = []

    writer = SummaryWriter(comment="SAC-GEP")
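The conv_filters entries above follow RLlib's [out_channels, [kernel, kernel], stride] convention, and the final filter is expected to collapse the remaining spatial extent to 1x1. A rough shape check for the 21x21 input, assuming "same" padding on the strided layers (a sketch, not RLlib's exact padding logic):

import math

size = 21
for stride in (2, 2):        # the 3x3/stride-2 and 2x2/stride-2 layers
    size = math.ceil(size / stride)
print(size)                  # 6 -> fully consumed by the final 6x6, stride-1 filter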
Example No. 23
    def execute(self):
        timesteps = 0
        best_period_value = None

        if self.pr.agent.name() == "A2C":
            trainer = A2CTrainer(config=self.rllib_config,
                                 logger_creator=rllib_logger_creator)
        elif self.pr.agent.name() == "PPO":
            trainer = PPOTrainer(config=self.rllib_config,
                                 logger_creator=rllib_logger_creator)
            # import pdb; pdb.set_trace()
        else:
            raise ValueError('There is no rllib trainer with name ' +
                             self.pr.agent.name())

        tf_writer = SummaryWriter(
            self.pr.save_logs_to) if self.pr.save_logs_to else None

        reward_metric = Metric(short_name='rews',
                               long_name='trajectory reward',
                               formatting_string='{:5.1f}',
                               higher_is_better=True)
        time_step_metric = Metric(short_name='steps',
                                  long_name='total number of steps',
                                  formatting_string='{:5.1f}',
                                  higher_is_better=True)

        metrics = [reward_metric, time_step_metric]

        if self.pr.train:
            start_time = time.time()
            policy_save_tag = 0
            while timesteps < self.pr.total_steps:

                result = trainer.train()

                timesteps = result["timesteps_total"]
                reward_metric.log(result['evaluation']['episode_reward_mean'])
                time_step_metric.log(result['evaluation']['episode_len_mean'])
                # import pdb; pdb.set_trace()
                # # Get a metric list from each environment.
                # if hasattr(trainer, "evaluation_workers"):
                #     metric_lists = sum(trainer.evaluation_workers.foreach_worker(lambda w: w.foreach_env(lambda e: e.metrics)), [])
                # else:
                #     metric_lists = sum(trainer.workers.foreach_worker(lambda w: w.foreach_env(lambda e: e.metrics)), [])

                # metrics = metric_lists[0]

                # # Aggregate metrics from all other environments.
                # for metric_list in metric_lists[1:]:
                #     for i, metric in enumerate(metric_list):
                #         metrics[i]._values.extend(metric._values)

                save_logs_to = self.pr.save_logs_to
                model_save_paths_dict = self.pr.model_save_paths_dict
                # Consider whether to save a model.
                saved = False
                if model_save_paths_dict is not None and metrics[
                        0].currently_optimal:
                    # trainer.get_policy().model.save(model_save_paths_dict)
                    policy_save_tag += 1
                    trainer.get_policy().model.save_model_in_progress(
                        model_save_paths_dict, policy_save_tag)
                    saved = True

                # Write the metrics for this reporting period.
                total_seconds = time.time() - start_time
                logger.write_and_condense_metrics(total_seconds, 'iters',
                                                  timesteps, saved, metrics,
                                                  tf_writer)

                # Clear the metrics, both those maintained by the training workers and by the evaluation ones.
                condense_fn = lambda environment: [
                    m.condense_values() for m in environment.metrics
                ]
                trainer.workers.foreach_worker(
                    lambda w: w.foreach_env(condense_fn))
                if hasattr(trainer, "evaluation_workers"):
                    trainer.evaluation_workers.foreach_worker(
                        lambda w: w.foreach_env(condense_fn))

        else:
            start_time = time.time()
            env = trainer.workers.local_worker().env
            metrics = env.metrics
            worker = trainer.workers.local_worker()
            steps = steps_since_report = 0

            while True:
                batch = worker.sample()
                current_steps = len(batch["obs"])
                steps += current_steps
                steps_since_report += current_steps

                if steps_since_report >= self.pr.reporting_interval:
                    total_seconds = time.time() - start_time

                    # Write the metrics for this reporting period.
                    logger.write_and_condense_metrics(total_seconds, 'iters',
                                                      steps, False, metrics,
                                                      tf_writer)

                    steps_since_report = 0
                    if steps >= self.pr.total_steps:
                        break

            env.close()

        # Get a summary metric for the entire stage, based on the environment's first metric.
        summary_metric = logger.summarize_stage(metrics[0])

        # Temporary workaround for https://github.com/ray-project/ray/issues/8205
        ray.shutdown()
        _register_all()

        return summary_metric
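The rllib_logger_creator passed to the trainers above is not shown in this snippet; the Trainer constructor accepts any callable that takes the config and returns a ray.tune logger. A minimal sketch of such a creator (directory layout is illustrative):

import os
import tempfile

from ray.tune.logger import UnifiedLogger


def rllib_logger_creator(config):
    logdir_root = os.path.expanduser("~/ray_results")
    os.makedirs(logdir_root, exist_ok=True)
    logdir = tempfile.mkdtemp(prefix="A2C_", dir=logdir_root)
    return UnifiedLogger(config, logdir)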
Example No. 24
if __name__ == "__main__":
    # Do not import torch for testing purposes.
    os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"

    from ray.rllib.agents.a3c import A2CTrainer

    assert "torch" not in sys.modules, "`torch` initially present, when it shouldn't!"

    # Note: No ray.init(), to test it works without Ray
    trainer = A2CTrainer(
        env="CartPole-v0",
        config={
            "framework": "tf",
            "num_workers": 0,
            # Disable the logger due to a soft-import attempt of torch
            # inside the tensorboardX.SummaryWriter class.
            "logger_config": {
                "type": "ray.tune.logger.NoopLogger",
            },
        },
    )
    trainer.train()

    assert (
        "torch" not in sys.modules
    ), "`torch` should not be imported after creating and training A3CTrainer!"

    # Clean up.
    del os.environ["RLLIB_TEST_NO_TORCH_IMPORT"]

    print("ok")
Example No. 25
def __init__(self, env, env_config, config):
    self.config = config
    self.config['env_config'] = env_config
    self.env = env(env_config)
    self.agent = A2CTrainer(config=self.config, env=env)

# Notice that trial_max will only work for stochastic policies
register_env(
    "ic20env", lambda _: SimplifiedIC20Environment(obs_state_processor,
                                                   act_state_processor,
                                                   UnstableReward(),
                                                   trial_max=10))
ten_gig = 10737418240

trainer = A2CTrainer(
    env="ic20env",
    config=merge_dicts(
        DEFAULT_CONFIG,
        {
            # -- Specific parameters
            'num_gpus': 0,
            'num_workers': 15,
            "num_envs_per_worker": 1,
            "num_cpus_per_worker": 1,
            "memory_per_worker": ten_gig,
            'gamma': 0.99,
        }))

# Attempt to restore from checkpoint if possible.
if os.path.exists(CHECKPOINT_FILE):
    checkpoint_path = open(CHECKPOINT_FILE).read()
    print("Restoring from checkpoint path", checkpoint_path)
    trainer.restore(checkpoint_path)

# Serving and training loop
while True:
Example No. 27
FUN_CONFIG = A2CTrainer.merge_trainer_configs(
    A2C_DEFAULT_CONFIG,
    {
        'use_gae': False,
        'lr': 1e-3,

        # Can be either constant, anneal, or cyclic
        'lr_mode': 'constant',

        # Linear learning rate annealing
        'end_lr': 1e-4,
        'anneal_timesteps': 10000000,

        # Cyclic learning rate
        'cyclic_lr_base_lr': 1e-4,
        'cyclic_lr_max_lr': 1e-3,
        'cyclic_lr_step_size': 200,
        'cyclic_lr_mode': 'triangular',
        'cyclic_lr_gamma': 0.99,
        'grad_clip': 0.5,
        'epsilon': 1e-8,
        'fun_horizon': 10,
        'model': {
            'custom_model_config': {
                'fun_horizon': 10
            }
        },
        '_use_trajectory_view_api': False,
    },
    _allow_unknown_configs=True,
)
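merge_trainer_configs deep-merges the override dict onto the base config, so nested sections such as 'model' keep their existing keys alongside the custom ones; _allow_unknown_configs=True is what permits non-standard top-level keys such as 'lr_mode' or 'fun_horizon'. A tiny illustration with made-up values:

merged = A2CTrainer.merge_trainer_configs(
    {"lr": 0.0001, "model": {"fcnet_hiddens": [64, 64]}},
    {"model": {"custom_model_config": {"fun_horizon": 10}}},
    _allow_unknown_configs=True,
)
# merged["model"] now contains both "fcnet_hiddens" and "custom_model_config".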
Example No. 28
def main() -> None:
    ray.init()
    np.random.seed(0)

    # instructions = {
    #     0: [Instruction(time=0, x=5, y=5)],
    #     1: [Instruction(time=1, x=5, y=5), Instruction(time=1, x=1, y=5)],
    #     2: [Instruction(time=2, x=5, y=5, rng=np.random.default_rng())],
    # }
    # task = Task(
    #     target_x=1,
    #     target_y=5,
    #     instructions=instructions,
    #     tot_frames=4,
    #     width=42,
    #     height=42,
    # )

    # task = ODR(target_x=1, target_y=5, width=42, height=42)
    # task = Gap(target_x=1, target_y=5, width=42, height=42)
    task = ODRDistract(target_x=1, target_y=5, width=42, height=42)

    def env_creator(env_config):
        return Environment(env_config)  # return an env instance

    register_env("my_env", env_creator)

    # trainer_config = DEFAULT_CONFIG.copy()
    # trainer_config["num_workers"] = 1
    # trainer_config["train_batch_size"] = 20  # 100
    # trainer_config["sgd_minibatch_size"] = 15  # 32
    # trainer_config["num_sgd_iter"] = 50

    trainer = PPOTrainer(
        env="my_env",
        config={
            "env_config": {"task": task},
            "framework": "torch",
            "num_workers": 1,
            "train_batch_size": 10,
            "sgd_minibatch_size": 5,
            "num_sgd_iter": 10,
            # "model": {
            #     # Whether to wrap the model with an LSTM.
            #     "use_lstm": True,
            #     # Max seq len for training the LSTM, defaults to 20.
            #     "max_seq_len": task.tot_frames - 1,
            #     # # Size of the LSTM cell.
            #     "lstm_cell_size": task.tot_frames - 1,
            #     # # Whether to feed a_{t-1}, r_{t-1} to LSTM.
            #     # # "lstm_use_prev_action_reward": False,
            # },
        },
    )

    trainer = A2CTrainer(
        env="my_env",
        config={
            "env_config": {"task": task},
            "framework": "torch",
            "num_workers": 1,
            "train_batch_size": 10,
            # "model": {
            #     # Whether to wrap the model with an LSTM.
            #     "use_lstm": True,
            #     # Max seq len for training the LSTM, defaults to 20.
            #     "max_seq_len": task.tot_frames - 1,
            #     # # Size of the LSTM cell.
            #     "lstm_cell_size": task.tot_frames - 1,
            #     # # Whether to feed a_{t-1}, r_{t-1} to LSTM.
            #     # # "lstm_use_prev_action_reward": False,
            # },
        },
    )

    # trainer = DQNTrainer(
    #     env="my_env",
    #     config={
    #         "env_config": {"task": task},
    #         "framework": "torch",
    #         "num_workers": 1,
    #         "train_batch_size": 10,
    #         # "model": {
    #         #     # Whether to wrap the model with an LSTM.
    #         #     "use_lstm": True,
    #         #     # Max seq len for training the LSTM, defaults to 20.
    #         #     "max_seq_len": task.tot_frames - 1,
    #         #     # # Size of the LSTM cell.
    #         #     "lstm_cell_size": task.tot_frames - 1,
    #         #     # # Whether to feed a_{t-1}, r_{t-1} to LSTM.
    #         #     # # "lstm_use_prev_action_reward": False,
    #         # },
    #     },
    # )

    env = Environment(env_config={"task": task})

    for i in range(200):
        print(f"Training iteration {i}...")
        trainer.train()

        done = False
        cumulative_reward = 0.0
        observation = env.reset()

        while not done:
            action = trainer.compute_action(observation)

            observation, reward, done, results = env.step(action)
            print(f"Time: {env.time}. Action: {action}")
            cumulative_reward += reward
        print(
            f"Last step reward: {reward: .3e}; Cumulative reward: {cumulative_reward:.3e}"
        )
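For a reproducible evaluation rollout, exploration can also be disabled per call (assuming a Ray 1.x compute_action signature):

action = trainer.compute_action(observation, explore=False)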
Example No. 29
            trainer = PPOTrainer(config=config_copy, env='Bertrand')
        elif trainer_choice == 'A3C':
            from ray.rllib.agents.a3c import A3CTrainer
            config['num_workers'] = 1
            # config['lr'] = 0.01
            # For eval afterward
            config_copy = config.copy()
            config_copy['explore'] = False
            trainer = A3CTrainer(config=config_copy, env='Bertrand')
        elif trainer_choice == 'A2C':
            from ray.rllib.agents.a3c import A2CTrainer
            config['num_workers'] = 1
            # For eval afterward
            config_copy = config.copy()
            config_copy['explore'] = False
            trainer = A2CTrainer(config=config_copy, env='Bertrand')
        elif trainer_choice == 'MADDPG':
            from ray.rllib.contrib.maddpg import MADDPGTrainer
            config['agent_id'] = 0
            # For eval afterward
            config_copy = config.copy()
            config_copy['explore'] = False
            trainer = MADDPGTrainer(config=config_copy, env='Bertrand')
        elif trainer_choice == 'DDPG':
            from ray.rllib.agents.ddpg import DDPGTrainer
            # For eval afterward
            config_copy = config.copy()
            config_copy['explore'] = False
            trainer = DDPGTrainer(config=config_copy, env='Bertrand')

        analysis = tune.run(