Exemplo n.º 1
0
    def do_test_nested_tuple(self, make_env):
        """Train PG on a Tuple-observation env and verify the spy's captures.

        Entries stored in Ray's internal KV store under ``t_spy_in_<i>``
        (presumably written by ``TupleSpyModel``) are unpickled and compared
        with ``TUPLE_SAMPLES``; the task component is compared one-hot encoded
        over 5 classes.

        Args:
            make_env: Env creator, registered under the name "nested2".
        """
        ModelCatalog.register_custom_model("composite2", TupleSpyModel)
        register_env("nested2", make_env)
        trainer_config = {
            "num_workers": 0,
            "rollout_fragment_length": 5,
            "train_batch_size": 5,
            "model": {"custom_model": "composite2"},
            "framework": "tf",
        }
        pg = PGTrainer(env="nested2", config=trainer_config)
        pg.train()

        # The model must see the correctly reconstructed observations.
        for idx in range(4):
            raw = ray.experimental.internal_kv._internal_kv_get(
                "t_spy_in_{}".format(idx))
            seen = pickle.loads(raw)
            sample = TUPLE_SAMPLES[idx]
            expected_pos = sample[0].tolist()
            expected_cam = sample[1][0].tolist()
            expected_task = one_hot(sample[2], 5)
            self.assertEqual(seen[0][0].tolist(), expected_pos)
            self.assertEqual(seen[1][0].tolist(), expected_cam)
            self.assertEqual(seen[2][0].tolist(), expected_task)
Exemplo n.º 2
0
    def test_custom_input_procedure(self):
        """Check that custom offline inputs can be specified in three forms.

        The input is given as a registered name, as the creator callable
        itself, and as a fully-qualified class path. Each form is tried for
        torch and tf. Every run must consume exactly 250 timesteps and report
        a NaN mean episode reward.
        """

        class CustomJsonReader(JsonReader):
            def __init__(self, ioctx: IOContext):
                super().__init__(ioctx.input_config["input_files"], ioctx)

        def input_creator(ioctx: IOContext) -> InputReader:
            return ShuffledInput(CustomJsonReader(ioctx))

        register_input("custom_input", input_creator)

        # Registered name, creator callable, and dotted class path.
        procedures = (
            "custom_input",
            input_creator,
            "ray.rllib.examples.custom_input_api.CustomJsonReader",
        )
        for procedure in procedures:
            for fw in framework_iterator(frameworks=("torch", "tf")):
                self.write_outputs(self.test_dir, fw)
                input_dir = self.test_dir + fw
                agent = PGTrainer(
                    env="CartPole-v0",
                    config={
                        "input": procedure,
                        "input_config": {"input_files": input_dir},
                        "input_evaluation": [],
                        "framework": fw,
                    },
                )
                result = agent.train()
                self.assertEqual(result["timesteps_total"], 250)
                # Presumably NaN because no live episodes are sampled.
                self.assertTrue(np.isnan(result["episode_reward_mean"]))
Exemplo n.º 3
0
 def test_callbacks(self):
     """Check that dict-style episode/sample callbacks fire during training.

     Each callback bumps its own counter. After two training iterations,
     every counter must be positive, for both torch and tf.
     """
     for fw in framework_iterator(frameworks=("torch", "tf")):
         counts = Counter()

         def counter_for(name):
             # Build a callback that records one hit under ``name``.
             return lambda x: counts.update({name: 1})

         pg = PGTrainer(env="CartPole-v0",
                        config={
                            "num_workers": 0,
                            "rollout_fragment_length": 50,
                            "train_batch_size": 50,
                            "callbacks": {
                                "on_episode_start": counter_for("start"),
                                "on_episode_step": counter_for("step"),
                                "on_episode_end": counter_for("end"),
                                "on_sample_end": counter_for("sample"),
                            },
                            "framework": fw,
                        })
         for _ in range(2):
             pg.train()
         for name in ("sample", "start", "end", "step"):
             self.assertGreater(counts[name], 0)
Exemplo n.º 4
0
    def do_test_nested_dict(self, make_env, test_lstm=False):
        """Train PG on a Dict-observation env and verify the spy's captures.

        Entries stored in Ray's internal KV store under ``d_spy_in_<i>``
        (presumably written by ``DictSpyModel``) are unpickled and compared
        with ``DICT_SAMPLES``.

        Args:
            make_env: Env creator, registered under the name "nested".
            test_lstm: Passed through as the model config's ``use_lstm`` flag.
        """
        ModelCatalog.register_custom_model("composite", DictSpyModel)
        register_env("nested", make_env)
        model_config = {
            "custom_model": "composite",
            "use_lstm": test_lstm,
        }
        pg = PGTrainer(env="nested",
                       config={
                           "num_workers": 0,
                           "rollout_fragment_length": 5,
                           "train_batch_size": 5,
                           "model": model_config,
                           "framework": "tf",
                       })
        pg.train()

        # The model must see the correctly reconstructed observations.
        for idx in range(4):
            raw = ray.experimental.internal_kv._internal_kv_get(
                "d_spy_in_{}".format(idx))
            seen = pickle.loads(raw)
            sensors = DICT_SAMPLES[idx]["sensors"]
            expected_pos = sensors["position"].tolist()
            expected_cam = sensors["front_cam"][0].tolist()
            raw_task = DICT_SAMPLES[idx]["inner_state"]["job_status"]["task"]
            expected_task = one_hot(raw_task, 5)
            self.assertEqual(seen[0][0].tolist(), expected_pos)
            self.assertEqual(seen[1][0].tolist(), expected_cam)
            self.assertEqual(seen[2][0].tolist(), expected_task)
Exemplo n.º 5
0
    def do_test_nested_tuple(self, make_env):
        """Train PG on a Tuple-observation env and verify the spy's captures.

        In this variant, the task component is compared via ``check`` against
        the raw sample value, with no one-hot conversion.

        Args:
            make_env: Env creator, registered under the name "nested2".
        """
        ModelCatalog.register_custom_model("composite2", TupleSpyModel)
        register_env("nested2", make_env)
        trainer_config = {
            "num_workers": 0,
            "rollout_fragment_length": 5,
            "train_batch_size": 5,
            "model": {"custom_model": "composite2"},
            "framework": "tf",
            "disable_env_checking": True,
        }
        pg = PGTrainer(env="nested2", config=trainer_config)
        # Reset the spy's capture counter so that the captures inspected below
        # come from training, not from forward passes made while the policy's
        # loss was being initialized.
        TupleSpyModel.capture_index = 0
        pg.train()

        # The model must see the correctly reconstructed observations.
        for idx in range(4):
            raw = ray.experimental.internal_kv._internal_kv_get(
                "t_spy_in_{}".format(idx))
            seen = pickle.loads(raw)
            sample = TUPLE_SAMPLES[idx]
            expected_pos = sample[0].tolist()
            expected_cam = sample[1][0].tolist()
            expected_task = sample[2]
            self.assertEqual(seen[0][0].tolist(), expected_pos)
            self.assertEqual(seen[1][0].tolist(), expected_cam)
            check(seen[2][0], expected_task)
Exemplo n.º 6
0
    def testAgentInputPostprocessingEnabled(self):
        """Check that ``postprocess_inputs`` lets PG train on stripped data.

        Offline data is first written to ``self.test_dir``. Then
        ``advantages`` and ``value_targets`` are removed from every JSON
        record. With ``postprocess_inputs`` enabled, PG must still train on
        the 250 offline timesteps.
        """
        self.writeOutputs(self.test_dir)

        # Rewrite the files to drop advantages and value_targets for testing.
        for path in glob.glob(self.test_dir + "/*.json"):
            out = []
            # Use a context manager so the read handle is closed before the
            # same path is reopened for writing.
            with open(path) as f:
                for line in f:
                    data = json.loads(line)
                    del data["advantages"]
                    del data["value_targets"]
                    out.append(data)
            with open(path, "w") as f:
                for data in out:
                    # One record per line keeps the file line-delimited JSON,
                    # so multi-record files are not merged into a single
                    # unparsable line.
                    f.write(json.dumps(data) + "\n")

        agent = PGTrainer(
            env="CartPole-v0",
            config={
                "input": self.test_dir,
                "input_evaluation": [],
                "postprocess_inputs": True,  # adds back 'advantages'
            })

        result = agent.train()
        self.assertEqual(result["timesteps_total"], 250)  # read from input
        self.assertTrue(np.isnan(result["episode_reward_mean"]))
Exemplo n.º 7
0
    def test_gpus_in_non_local_mode(self):
        """Check GPU-related config validation when Ray is not in local mode.

        The test iterates over combinations of ``num_gpus``,
        ``num_gpus_per_worker`` and ``_fake_gpus`` for several frameworks.
        When the requested GPUs (driver plus two workers) exceed what torch
        reports and GPUs are not faked, creating a PGTrainer must raise a
        RuntimeError matching "Found 0 GPUs on your machine". Otherwise the
        trainer must build and, for ``num_gpus == 0``, also run through
        ``tune.run`` with a zero-iteration stop criterion.
        """
        # Non-local mode.
        ray.init(num_cpus=8)
        # Always shut Ray down, even when an assertion fails, so a failing
        # case does not leave a live Ray instance behind for later tests.
        try:
            actual_gpus = torch.cuda.device_count()
            print(f"Actual GPUs found (by torch): {actual_gpus}")

            config = DEFAULT_CONFIG.copy()
            config["num_workers"] = 2
            config["env"] = "CartPole-v0"

            # Expect errors when we run a config w/ num_gpus>0 w/o a GPU
            # and _fake_gpus=False.
            for num_gpus in [0, 0.1, 1, actual_gpus + 4]:
                # Only allow possible num_gpus_per_worker (so test would not
                # block infinitely due to a down worker).
                if actual_gpus == 0 or actual_gpus < num_gpus:
                    per_worker = [0]
                else:
                    per_worker = [0, 0.5, 1]
                # Faking GPUs only makes sense when some are requested.
                fake_options = [False] if num_gpus == 0 else [False, True]
                for num_gpus_per_worker in per_worker:
                    for fake_gpus in fake_options:
                        config["num_gpus"] = num_gpus
                        config["num_gpus_per_worker"] = num_gpus_per_worker
                        config["_fake_gpus"] = fake_gpus

                        print(f"\n------------\nnum_gpus={num_gpus} "
                              f"num_gpus_per_worker={num_gpus_per_worker} "
                              f"_fake_gpus={fake_gpus}")

                        # Driver GPUs plus GPUs for the two workers.
                        gpus_needed = num_gpus + 2 * num_gpus_per_worker
                        expect_error = (actual_gpus < gpus_needed
                                        and not fake_gpus)

                        frameworks = ("tf", "torch") if num_gpus > 1 else \
                            ("tf2", "tf", "torch")
                        for _ in framework_iterator(config,
                                                    frameworks=frameworks):
                            if expect_error:
                                # "Direct" RLlib (create Trainer on the
                                # driver). Cannot run through ray.tune.run()
                                # as it would simply wait infinitely for the
                                # resources to become available.
                                print("direct RLlib")
                                self.assertRaisesRegex(
                                    RuntimeError,
                                    "Found 0 GPUs on your machine",
                                    lambda: PGTrainer(config,
                                                      env="CartPole-v0"),
                                )
                            # If actual_gpus >= num_gpus or faked,
                            # expect no error.
                            else:
                                print("direct RLlib")
                                trainer = PGTrainer(config, env="CartPole-v0")
                                trainer.stop()
                                # Cannot run through ray.tune.run() w/ fake
                                # GPUs as it would simply wait infinitely for
                                # the resources to become available (even
                                # though, we wouldn't really need them).
                                if num_gpus == 0:
                                    print("via ray.tune.run()")
                                    tune.run("PG",
                                             config=config,
                                             stop={"training_iteration": 0})
        finally:
            ray.shutdown()
Exemplo n.º 8
0
    def test_gpus_in_local_mode(self):
        """Check that every GPU config works when Ray runs in local mode.

        Ray is started in local mode declaring 8 GPUs. For every
        ``num_gpus`` / ``_fake_gpus`` combination and framework, a PGTrainer
        must build and stop cleanly, and ``tune.run`` must complete with a
        zero-iteration stop criterion without raising.
        """
        # Local mode.
        ray.init(num_gpus=8, local_mode=True)
        # Shut Ray down even if an assertion fails, so later tests do not
        # inherit a live Ray instance.
        try:
            actual_gpus_available = torch.cuda.device_count()

            config = DEFAULT_CONFIG.copy()
            config["num_workers"] = 2
            config["env"] = "CartPole-v0"

            # Expect no errors in local mode.
            for num_gpus in [0, 0.1, 1, actual_gpus_available + 4]:
                print(f"num_gpus={num_gpus}")
                for fake_gpus in [False, True]:
                    print(f"_fake_gpus={fake_gpus}")
                    config["num_gpus"] = num_gpus
                    config["_fake_gpus"] = fake_gpus
                    frameworks = ("tf", "torch") if num_gpus > 1 else \
                        ("tf2", "tf", "torch")
                    for _ in framework_iterator(config, frameworks=frameworks):
                        print("direct RLlib")
                        trainer = PGTrainer(config, env="CartPole-v0")
                        trainer.stop()
                        print("via ray.tune.run()")
                        tune.run("PG",
                                 config=config,
                                 stop={"training_iteration": 0})
        finally:
            ray.shutdown()
Exemplo n.º 9
0
 def test_pg_exec_impl(ray_start_regular):
     """Check that PG with ``use_exec_api`` trains and returns a result dict.

     ``ray_start_regular`` is a fixture argument that the body does not use
     (presumably it starts Ray for the duration of the test).
     """
     exec_config = {"min_iter_time_s": 0, "use_exec_api": True}
     trainer = PGTrainer(env="CartPole-v0", config=exec_config)
     result = trainer.train()
     assert isinstance(result, dict)
Exemplo n.º 10
0
    def test_multi_agent_complex_spaces(self):
        """Train PG on a multi-agent env mixing Tuple and Dict observations.

        Each agent is mapped to its own PGTFPolicy with a custom spy model.
        The captures stored in Ray's internal KV store under
        ``d_spy_in_<i>`` / ``t_spy_in_<i>`` (presumably written by the spy
        models) are compared with ``DICT_SAMPLES`` and ``TUPLE_SAMPLES``.
        """
        ModelCatalog.register_custom_model("dict_spy", DictSpyModel)
        ModelCatalog.register_custom_model("tuple_spy", TupleSpyModel)
        register_env("nested_ma", lambda _: NestedMultiAgentEnv())
        act_space = spaces.Discrete(2)

        policies = {
            "tuple_policy": (PGTFPolicy, TUPLE_SPACE, act_space, {
                "model": {"custom_model": "tuple_spy"},
            }),
            "dict_policy": (PGTFPolicy, DICT_SPACE, act_space, {
                "model": {"custom_model": "dict_spy"},
            }),
        }
        multiagent_config = {
            "policies": policies,
            "policy_mapping_fn": lambda a: {
                "tuple_agent": "tuple_policy",
                "dict_agent": "dict_policy"
            }[a],
        }
        pg = PGTrainer(
            env="nested_ma",
            config={
                "num_workers": 0,
                "rollout_fragment_length": 5,
                "train_batch_size": 5,
                "multiagent": multiagent_config,
                "framework": "tf",
            })
        # Reset both spies' capture counters so the captures inspected below
        # come from training, not from the policies' loss initialization.
        TupleSpyModel.capture_index = DictSpyModel.capture_index = 0
        pg.train()

        kv_get = ray.experimental.internal_kv._internal_kv_get

        for idx in range(4):
            seen = pickle.loads(kv_get("d_spy_in_{}".format(idx)))
            sample = DICT_SAMPLES[idx]
            expected_pos = sample["sensors"]["position"].tolist()
            expected_cam = sample["sensors"]["front_cam"][0].tolist()
            expected_task = one_hot(
                sample["inner_state"]["job_status"]["task"], 5)
            self.assertEqual(seen[0][0].tolist(), expected_pos)
            self.assertEqual(seen[1][0].tolist(), expected_cam)
            check(seen[2][0], expected_task)

        for idx in range(4):
            seen = pickle.loads(kv_get("t_spy_in_{}".format(idx)))
            sample = TUPLE_SAMPLES[idx]
            expected_pos = sample[0].tolist()
            expected_cam = sample[1][0].tolist()
            expected_task = one_hot(sample[2], 5)
            self.assertEqual(seen[0][0].tolist(), expected_pos)
            self.assertEqual(seen[1][0].tolist(), expected_cam)
            check(seen[2][0], expected_task)
Exemplo n.º 11
0
 def writeOutputs(self, output):
     """Train PG on CartPole for one iteration, writing samples to ``output``.

     Uses ``sample_batch_size`` 250. Returns the trainer so callers can
     reuse or stop it.
     """
     writer_config = {
         "output": output,
         "sample_batch_size": 250,
     }
     agent = PGTrainer(env="CartPole-v0", config=writer_config)
     agent.train()
     return agent
Exemplo n.º 12
0
 def writeOutputs(self, output):
     """Train PG on CartPole for one iteration, writing samples to ``output``.

     Uses ``rollout_fragment_length`` 250. Returns the trainer so callers
     can reuse or stop it.
     """
     writer_config = {
         "output": output,
         "rollout_fragment_length": 250,
     }
     agent = PGTrainer(env="CartPole-v0", config=writer_config)
     agent.train()
     return agent
Exemplo n.º 13
0
    def test_local(self):
        """Smoke-test PG on CartPole in every framework with a tiny model.

        Each framework gets one training iteration, whose result is printed,
        and then the trainer is stopped.
        """
        import copy

        # Deep copy: ``dict.copy()`` is shallow, so writing into cf["model"]
        # would mutate the nested model dict shared with DEFAULT_CONFIG and
        # leak fcnet_hiddens=[10] into anything else that reads the defaults.
        cf = copy.deepcopy(DEFAULT_CONFIG)
        cf["model"]["fcnet_hiddens"] = [10]

        for _ in framework_iterator(cf):
            agent = PGTrainer(cf, "CartPole-v0")
            print(agent.train())
            agent.stop()
Exemplo n.º 14
0
 def write_outputs(self, output, fw):
     """Train PG once with framework ``fw``, writing samples under ``output``.

     Unless ``output`` is the special value "logdir", the framework name is
     appended to it, so each framework writes to its own path. Returns the
     trainer.
     """
     output_path = output if output == "logdir" else output + fw
     agent = PGTrainer(env="CartPole-v0",
                       config={
                           "output": output_path,
                           "rollout_fragment_length": 250,
                           "framework": fw,
                       })
     agent.train()
     return agent
Exemplo n.º 15
0
 def testTrainCartpole(self):
     """Check that PG learns CartPole when wrapped in SimpleServing.

     Passes as soon as the mean episode reward reaches 100, trying at most
     100 iterations. Otherwise it raises.
     """
     register_env("test", lambda _: SimpleServing(gym.make("CartPole-v0")))
     pg = PGTrainer(env="test", config={"num_workers": 0})
     for iteration in range(100):
         result = pg.train()
         reward = result["episode_reward_mean"]
         timesteps = result["timesteps_total"]
         print("Iteration {}, reward {}, timesteps {}".format(
             iteration, reward, timesteps))
         if reward >= 100:
             return
     raise Exception("failed to improve reward")
Exemplo n.º 16
0
 def test_no_step_on_init(self):
     """Check that building a trainer on FailOnStepEnv succeeds, train() fails.

     Construction happens outside ``assertRaises``, so it must not raise.
     The subsequent ``train()`` call must raise.
     """
     register_env("fail", lambda _: FailOnStepEnv())
     # NOTE(review): ``frameworks=()`` appears to select no frameworks, which
     # would make this loop (and therefore the test) a no-op. Confirm this is
     # intentional.
     for fw in framework_iterator(frameworks=()):
         fail_config = {"num_workers": 1, "framework": fw}
         pg = PGTrainer(env="fail", config=fail_config)
         self.assertRaises(Exception, pg.train)
         pg.stop()
Exemplo n.º 17
0
 def testAgentInputDir(self):
     """Check that PG can train from a directory of offline JSON output."""
     self.writeOutputs(self.test_dir)
     input_config = {
         "input": self.test_dir,
         "input_evaluation": [],
     }
     agent = PGTrainer(env="CartPole-v0", config=input_config)
     result = agent.train()
     # All 250 timesteps are read from the input directory.
     self.assertEqual(result["timesteps_total"], 250)
     self.assertTrue(np.isnan(result["episode_reward_mean"]))
Exemplo n.º 18
0
    def testMultiAgentComplexSpaces(self):
        """Train PG on a multi-agent env mixing Tuple and Dict observations.

        Each agent is mapped to its own PGTFPolicy with a custom spy model.
        The captures stored in Ray's internal KV store under
        ``d_spy_in_<i>`` / ``t_spy_in_<i>`` (presumably written by the spy
        models) are compared with ``DICT_SAMPLES`` and ``TUPLE_SAMPLES``.
        """
        ModelCatalog.register_custom_model("dict_spy", DictSpyModel)
        ModelCatalog.register_custom_model("tuple_spy", TupleSpyModel)
        register_env("nested_ma", lambda _: NestedMultiAgentEnv())
        act_space = spaces.Discrete(2)

        policies = {
            "tuple_policy": (PGTFPolicy, TUPLE_SPACE, act_space, {
                "model": {"custom_model": "tuple_spy"},
            }),
            "dict_policy": (PGTFPolicy, DICT_SPACE, act_space, {
                "model": {"custom_model": "dict_spy"},
            }),
        }
        multiagent_config = {
            "policies": policies,
            "policy_mapping_fn": lambda a: {
                "tuple_agent": "tuple_policy",
                "dict_agent": "dict_policy"
            }[a],
        }
        pg = PGTrainer(
            env="nested_ma",
            config={
                "num_workers": 0,
                "sample_batch_size": 5,
                "train_batch_size": 5,
                "multiagent": multiagent_config,
            })
        pg.train()

        kv_get = ray.experimental.internal_kv._internal_kv_get

        for idx in range(4):
            seen = pickle.loads(kv_get("d_spy_in_{}".format(idx)))
            sample = DICT_SAMPLES[idx]
            expected_pos = sample["sensors"]["position"].tolist()
            expected_cam = sample["sensors"]["front_cam"][0].tolist()
            expected_task = one_hot(
                sample["inner_state"]["job_status"]["task"], 5)
            self.assertEqual(seen[0][0].tolist(), expected_pos)
            self.assertEqual(seen[1][0].tolist(), expected_cam)
            self.assertEqual(seen[2][0].tolist(), expected_task)

        for idx in range(4):
            seen = pickle.loads(kv_get("t_spy_in_{}".format(idx)))
            sample = TUPLE_SAMPLES[idx]
            expected_pos = sample[0].tolist()
            expected_cam = sample[1][0].tolist()
            expected_task = one_hot(sample[2], 5)
            self.assertEqual(seen[0][0].tolist(), expected_pos)
            self.assertEqual(seen[1][0].tolist(), expected_cam)
            self.assertEqual(seen[2][0].tolist(), expected_task)
Exemplo n.º 19
0
 def test_train_multi_cartpole_single_policy(self):
     """Check that one PG policy learns a 10-copy multi-agent CartPole.

     Passes once the mean episode reward reaches 50 per sub-env (500 in
     total), trying at most 100 iterations. Otherwise it raises.
     """
     n = 10
     target_reward = 50 * n
     register_env("multi_cartpole", lambda _: MultiCartpole(n))
     pg = PGTrainer(env="multi_cartpole", config={"num_workers": 0})
     for iteration in range(100):
         result = pg.train()
         reward = result["episode_reward_mean"]
         timesteps = result["timesteps_total"]
         print("Iteration {}, reward {}, timesteps {}".format(
             iteration, reward, timesteps))
         if reward >= target_reward:
             return
     raise Exception("failed to improve reward")
Exemplo n.º 20
0
 def test_no_step_on_init(self):
     """Check that building a trainer on FailOnStepEnv succeeds, train() fails.

     Construction happens outside ``assertRaises``, so it must not raise.
     The subsequent ``train()`` call must raise.
     """
     # Allow for Unittest run.
     ray.init(num_cpus=5, ignore_reinit_error=True)
     register_env("fail", lambda _: FailOnStepEnv())
     # NOTE(review): ``frameworks=()`` appears to select no frameworks, which
     # would make this loop (and therefore the test) a no-op. Confirm this is
     # intentional.
     for fw in framework_iterator(frameworks=()):
         pg = PGTrainer(env="fail",
                        config={
                            "num_workers": 1,
                            "framework": fw,
                        })
         self.assertRaises(Exception, lambda: pg.train())
         # Stop the trainer (and its remote worker) so it does not outlive
         # the test. This matches the sibling variant of this test.
         pg.stop()
Exemplo n.º 21
0
 def testAgentInputList(self):
     """Check that PG can train from an explicit list of JSON input files."""
     self.writeOutputs(self.test_dir)
     input_files = glob.glob(self.test_dir + "/*.json")
     list_config = {
         "input": input_files,
         "input_evaluation": [],
         "rollout_fragment_length": 99,
     }
     agent = PGTrainer(env="CartPole-v0", config=list_config)
     result = agent.train()
     # All 250 timesteps are read from the listed input files.
     self.assertEqual(result["timesteps_total"], 250)
     self.assertTrue(np.isnan(result["episode_reward_mean"]))
Exemplo n.º 22
0
 def test_agent_input_dir(self):
     """Check that PG trains from a per-framework offline output directory."""
     for fw in framework_iterator(frameworks=("torch", "tf")):
         self.write_outputs(self.test_dir, fw)
         input_dir = self.test_dir + fw
         agent = PGTrainer(env="CartPole-v0",
                           config={
                               "input": input_dir,
                               "input_evaluation": [],
                               "framework": fw,
                           })
         result = agent.train()
         # All 250 timesteps are read from the input directory.
         self.assertEqual(result["timesteps_total"], 250)
         self.assertTrue(np.isnan(result["episode_reward_mean"]))
Exemplo n.º 23
0
 def test_multi_agent_with_flex_agents(self):
     """Smoke-test 10 PG iterations on a FlexAgentsMultiAgent env.

     Progress is only printed; no assertion is made on the reward.
     """
     register_env("flex_agents_multi_agent_cartpole",
                  lambda _: FlexAgentsMultiAgent())
     flex_config = {"num_workers": 0, "framework": "tf"}
     pg = PGTrainer(env="flex_agents_multi_agent_cartpole",
                    config=flex_config)
     for iteration in range(10):
         result = pg.train()
         reward = result["episode_reward_mean"]
         timesteps = result["timesteps_total"]
         print("Iteration {}, reward {}, timesteps {}".format(
             iteration, reward, timesteps))
Exemplo n.º 24
0
 def testAgentInputEvalSim(self):
     """Check that "simulation" input evaluation on offline data yields rewards.

     Trains from offline input for at most 50 iterations and passes as soon
     as a non-NaN ``episode_reward_mean`` is reported.
     """
     self.writeOutputs(self.test_dir)
     agent = PGTrainer(env="CartPole-v0",
                       config={
                           "input": self.test_dir,
                           "input_evaluation": ["simulation"],
                       })
     for _ in range(50):
         result = agent.train()
         if not np.isnan(result["episode_reward_mean"]):
             return  # simulation ok
         time.sleep(0.1)
     # ``self.fail`` rather than ``assert False``: plain asserts are stripped
     # under ``python -O``, which would let this test pass silently.
     self.fail("did not see any simulation results")
Exemplo n.º 25
0
 def testAgentInputDict(self):
     """Check mixing offline input (weight 0.1) with live sampling (0.9).

     A non-NaN mean episode reward must be reported, presumably because part
     of the batch comes from live sampling.
     """
     self.writeOutputs(self.test_dir)
     input_mix = {
         self.test_dir: 0.1,
         "sampler": 0.9,
     }
     agent = PGTrainer(env="CartPole-v0",
                       config={
                           "input": input_mix,
                           "train_batch_size": 2000,
                           "input_evaluation": [],
                       })
     reward_mean = agent.train()["episode_reward_mean"]
     self.assertTrue(not np.isnan(reward_mean))
Exemplo n.º 26
0
    def testMultiAgent(self):
        """Write multi-agent offline data, then train from it with simulation.

        A 10-agent CartPole PG run writes its experiences to
        ``self.test_dir``. A second trainer then reads that directory with
        "simulation" input evaluation and must report a non-NaN mean episode
        reward within 50 iterations.
        """
        register_env("multi_agent_cartpole",
                     lambda _: MultiAgentCartPole({"num_agents": 10}))
        single_env = gym.make("CartPole-v0")

        def gen_policy():
            obs_space = single_env.observation_space
            act_space = single_env.action_space
            return (PGTFPolicy, obs_space, act_space, {})

        for fw in framework_iterator():
            pg = PGTrainer(
                env="multi_agent_cartpole",
                config={
                    "num_workers": 0,
                    "output": self.test_dir,
                    "multiagent": {
                        "policies": {
                            "policy_1": gen_policy(),
                            "policy_2": gen_policy(),
                        },
                        "policy_mapping_fn": (
                            lambda agent_id: random.choice(
                                ["policy_1", "policy_2"])),
                    },
                    "framework": fw,
                })
            pg.train()
            self.assertEqual(len(os.listdir(self.test_dir)), 1)

            pg.stop()
            pg = PGTrainer(
                env="multi_agent_cartpole",
                config={
                    "num_workers": 0,
                    "input": self.test_dir,
                    "input_evaluation": ["simulation"],
                    "train_batch_size": 2000,
                    "multiagent": {
                        "policies": {
                            "policy_1": gen_policy(),
                            "policy_2": gen_policy(),
                        },
                        "policy_mapping_fn": (
                            lambda agent_id: random.choice(
                                ["policy_1", "policy_2"])),
                    },
                    "framework": fw,
                })
            # NOTE(review): the ``return`` below exits the whole test after
            # the first framework succeeds, so later frameworks are never
            # exercised. Turning it into a ``break`` would likely also need a
            # per-framework output dir, because the listdir check above
            # expects exactly one file in ``self.test_dir``.
            for _ in range(50):
                result = pg.train()
                if not np.isnan(result["episode_reward_mean"]):
                    return  # simulation ok
                time.sleep(0.1)
            # ``self.fail`` rather than ``assert False``: plain asserts are
            # stripped under ``python -O``, which would let this pass silently.
            self.fail("did not see any simulation results")
Exemplo n.º 27
0
    def test_multi_agent_dict_invalid_sub_values(self):
        """Check that invalid ``multiagent`` sub-settings raise ValueError.

        Each listed key is set to "invalid_value", and building a PGTrainer
        must raise a ValueError whose message matches the paired pattern.
        """
        cases = (
            ("count_steps_by", "config.multiagent.count_steps_by must be"),
            ("replay_mode", "config.multiagent.replay_mode must be"),
        )
        for key, pattern in cases:
            config = {"multiagent": {key: "invalid_value"}}
            self.assertRaisesRegex(
                ValueError,
                pattern,
                lambda: PGTrainer(config, env="CartPole-v0"),
            )
Exemplo n.º 28
0
 def test_agent_input_eval_sim(self):
     """Check "simulation" input evaluation on offline data for each framework.

     For every framework, offline data is written and PG is trained on it
     until a non-NaN ``episode_reward_mean`` is reported, trying at most 50
     iterations. The test fails if any framework never produces one.
     """
     for fw in framework_iterator():
         self.write_outputs(self.test_dir, fw)
         agent = PGTrainer(env="CartPole-v0",
                           config={
                               "input": self.test_dir + fw,
                               "input_evaluation": ["simulation"],
                               "framework": fw,
                           })
         for _ in range(50):
             result = agent.train()
             if not np.isnan(result["episode_reward_mean"]):
                 # Simulation ok for this framework. ``break`` (not
                 # ``return``) so the remaining frameworks are checked too.
                 break
             time.sleep(0.1)
         else:
             # ``self.fail`` is used because ``assert`` is stripped under -O.
             self.fail("did not see any simulation results")
         agent.stop()
Exemplo n.º 29
0
 def test_agent_input_dict(self):
     """Check mixing offline input (0.1) with live sampling (0.9) per framework.

     A non-NaN mean episode reward must be reported for each framework,
     presumably because part of each batch comes from live sampling.
     """
     for fw in framework_iterator():
         self.write_outputs(self.test_dir, fw)
         input_mix = {
             self.test_dir + fw: 0.1,
             "sampler": 0.9,
         }
         agent = PGTrainer(env="CartPole-v0",
                           config={
                               "input": input_mix,
                               "train_batch_size": 2000,
                               "input_evaluation": [],
                               "framework": fw,
                           })
         reward_mean = agent.train()["episode_reward_mean"]
         self.assertTrue(not np.isnan(reward_mean))
Exemplo n.º 30
0
def run_with_custom_entropy_loss(args, stop):
    """Example of customizing the loss function of an existing policy.

    This performs about the same as the default loss does. The loss is the
    batch mean of ``-0.1 * entropy - logp(actions) * advantages``.

    Args:
        args: Parsed command-line args. ``args.torch`` selects the Torch
            policy (otherwise TF). The args are also forwarded to
            ``run_heuristic_vs_learned``.
        stop: Accepted but not used by this function.
    """
    def entropy_policy_gradient_loss(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        if args.torch:
            # required by PGTorchPolicy's stats fn.
            policy.pi_err = torch.tensor([0.0])
            return torch.mean(-0.1 * action_dist.entropy() -
                              (action_dist.logp(train_batch["actions"]) *
                               train_batch["advantages"]))
        else:
            # Reduce the whole per-sample objective to a scalar, exactly as
            # the torch branch does, so the entropy bonus is averaged over the
            # batch rather than left as an unreduced per-sample tensor.
            return tf.reduce_mean(-0.1 * action_dist.entropy() -
                                  (action_dist.logp(train_batch["actions"]) *
                                   train_batch["advantages"]))

    policy_cls = PGTorchPolicy if args.torch else PGTFPolicy
    EntropyPolicy = policy_cls.with_updates(
        loss_fn=entropy_policy_gradient_loss)

    EntropyLossPG = PGTrainer.with_updates(
        name="EntropyPG", get_policy_class=lambda _: EntropyPolicy)

    run_heuristic_vs_learned(args, use_lstm=True, trainer=EntropyLossPG)