Ejemplo n.º 1
0
    def testPyTorchModel(self):
        """Train A2C with a PyTorch custom model on the nested-dict env.

        Registers ``TorchSpyModel`` as the ``"composite"`` custom model and
        ``NestedDictEnv`` as the ``"nested"`` env, then runs one training
        iteration. Afterwards, reads the pickled records stored in Ray's
        internal KV store under ``torch_spy_in_<i>`` (presumably written
        by ``TorchSpyModel``, which is defined elsewhere). For the first four
        entries of ``DICT_SAMPLES``, it checks that the first row of each
        recorded component matches that sample's position, front camera
        frame and one-hot task.
        """
        ModelCatalog.register_custom_model("composite", TorchSpyModel)
        register_env("nested", lambda _: NestedDictEnv())
        trainer_config = {
            "num_workers": 0,
            "use_pytorch": True,
            "sample_batch_size": 5,
            "train_batch_size": 5,
            "model": {"custom_model": "composite"},
        }
        a2c = A2CAgent(env="nested", config=trainer_config)

        a2c.train()

        # Check that the model sees the correct reconstructed observations.
        fetch = ray.experimental.internal_kv._internal_kv_get
        for idx in range(4):
            seen = pickle.loads(fetch("torch_spy_in_{}".format(idx)))
            sample = DICT_SAMPLES[idx]
            # Expected values, in the same slot order as the spy record:
            # 0 -> position, 1 -> first front_cam frame, 2 -> one-hot task.
            expected = [
                sample["sensors"]["position"].tolist(),
                sample["sensors"]["front_cam"][0].tolist(),
                one_hot(sample["inner_state"]["job_status"]["task"], 5),
            ]
            for slot, want in enumerate(expected):
                self.assertEqual(seen[slot][0].tolist(), want)
Ejemplo n.º 2
0
 def testGlobalVarsUpdate(self):
     """Verify the learner's ``cur_lr`` follows ``lr_schedule`` over time.

     The schedule is passed as ``[timestep, lr]`` pairs. After the first
     ``train()`` call, the reported ``info.learner.cur_lr`` must still be
     above 0.01. After the second call it must have dropped below 0.0001.
     This presumably exercises how the trainer pushes global variables
     (such as the timestep) to the learner; confirm against the trainer code.
     """
     schedule = [[0, 0.1], [400, 0.000001]]
     agent = A2CAgent(env="CartPole-v0", config={"lr_schedule": schedule})

     def learner_lr(train_result):
         # Learning rate the learner reported for this training iteration.
         return train_result["info"]["learner"]["cur_lr"]

     first_lr = learner_lr(agent.train())
     self.assertGreater(first_lr, 0.01)
     second_lr = learner_lr(agent.train())
     self.assertLess(second_lr, 0.0001)