def testPyTorchModel(self):
    """Train A2C with a custom PyTorch model and check what the model saw.

    Registers ``TorchSpyModel`` under the name "composite" and a
    ``NestedDictEnv`` factory under the name "nested", runs one training
    iteration, then compares the inputs recorded in the internal KV store
    (keys ``torch_spy_in_0`` .. ``torch_spy_in_3``) against ``DICT_SAMPLES``.
    """
    ModelCatalog.register_custom_model("composite", TorchSpyModel)
    register_env("nested", lambda _: NestedDictEnv())

    agent_config = {
        "num_workers": 0,
        "use_pytorch": True,
        "sample_batch_size": 5,
        "train_batch_size": 5,
        "model": {
            "custom_model": "composite",
        },
    }
    a2c = A2CAgent(env="nested", config=agent_config)
    a2c.train()

    # Check that the model sees the correct reconstructed observations.
    # NOTE(review): the KV entries are presumably written by TorchSpyModel
    # during the forward pass -- confirm against its definition.
    for step in range(4):
        kv_key = "torch_spy_in_{}".format(step)
        payload = ray.experimental.internal_kv._internal_kv_get(kv_key)
        observed = pickle.loads(payload)

        sample = DICT_SAMPLES[step]
        expected_position = sample["sensors"]["position"].tolist()
        expected_camera = sample["sensors"]["front_cam"][0].tolist()
        expected_task = one_hot(
            sample["inner_state"]["job_status"]["task"], 5)

        # Each recorded component is indexed at [0] before comparison.
        self.assertEqual(observed[0][0].tolist(), expected_position)
        self.assertEqual(observed[1][0].tolist(), expected_camera)
        self.assertEqual(observed[2][0].tolist(), expected_task)
def testGlobalVarsUpdate(self):
    """Check that the reported learning rate follows the configured schedule.

    Builds an A2C agent on "CartPole-v0" with an ``lr_schedule`` going from
    0.1 at step 0 to 0.000001 at step 400, then asserts that ``cur_lr`` is
    above 0.01 after the first ``train()`` call and below 0.0001 after the
    second.
    """
    schedule = [[0, 0.1], [400, 0.000001]]
    agent = A2CAgent(env="CartPole-v0", config={"lr_schedule": schedule})

    first_result = agent.train()
    first_lr = first_result["info"]["learner"]["cur_lr"]
    self.assertGreater(first_lr, 0.01)

    second_result = agent.train()
    second_lr = second_result["info"]["learner"]["cur_lr"]
    self.assertLess(second_lr, 0.0001)