Пример #1
0
 def _test_dqn_workflow(self, use_gpu=False, use_all_avail_gpus=False):
     """Run DQN workflow to ensure no crashes, algorithm correctness
     not tested here.

     Calls ``dqn_workflow.main`` with ``"epochs": 1`` on the CartPole test
     fixtures, writing its outputs into a temporary directory. It then loads
     the single exported ``predictor_*.c2`` file and checks that predicting
     one state returns one entry per configured action.

     Args:
         use_gpu: forwarded to the workflow as the ``"use_gpu"`` param.
         use_all_avail_gpus: forwarded to the workflow as the
             ``"use_all_avail_gpus"`` param.
     """
     with tempfile.TemporaryDirectory() as tmpdirname:
         # The workflow is given a ``file://`` init_method pointing at this
         # file, so it is created up front. NOTE(review): presumably used for
         # multi-process rendezvous — confirm against dqn_workflow.
         lockfile = os.path.join(tmpdirname, "multiprocess_lock")
         Path(lockfile).touch()
         # Kept in a local so the final check derives the expected number of
         # Q-values from the configuration instead of a magic constant.
         actions = ["0", "1"]
         params = {
             "training_data_path": os.path.join(
                 curr_dir, "test_data/discrete_action/cartpole_training.json.bz2"
             ),
             "eval_data_path": os.path.join(
                 curr_dir, "test_data/discrete_action/cartpole_eval.json.bz2"
             ),
             "state_norm_data_path": os.path.join(
                 curr_dir, "test_data/discrete_action/cartpole_norm.json"
             ),
             "model_output_path": tmpdirname,
             "use_gpu": use_gpu,
             "use_all_avail_gpus": use_all_avail_gpus,
             "init_method": "file://" + lockfile,
             "num_nodes": 1,
             "node_index": 0,
             "actions": actions,
             "epochs": 1,
             "rl": {},
             "rainbow": {"double_q_learning": False, "dueling_architecture": False},
             "training": {"minibatch_size": 128},
         }
         dqn_workflow.main(params)
         # Build the pattern with os.path.join rather than string concatenation.
         predictor_files = glob.glob(os.path.join(tmpdirname, "predictor_*.c2"))
         # Report the real count: zero files is as much a failure as several.
         assert len(predictor_files) == 1, (
             "Expected exactly one predictor file, found {}: {}".format(
                 len(predictor_files), predictor_files
             )
         )
         predictor = DQNPredictor.load(predictor_files[0], "minidb")
         test_float_state_features = [{"0": 1.0, "1": 1.0, "2": 1.0, "3": 1.0}]
         q_values = predictor.predict(test_float_state_features)
     # One Q-value entry is expected per configured action.
     assert len(q_values[0].keys()) == len(actions)
Пример #2
0
 def _test_dqn_workflow(self, use_gpu=False, use_all_avail_gpus=False):
     """Smoke-test the DQN workflow end to end.

     Calls ``dqn_workflow.main`` with ``"epochs": 1`` on the CartPole test
     fixtures and uses the predictor it returns to score a single state. Only
     checks that nothing crashes and that the prediction has two entries;
     algorithm correctness is not tested here.

     Args:
         use_gpu: passed through as the ``"use_gpu"`` workflow param.
         use_all_avail_gpus: passed through as the ``"use_all_avail_gpus"``
             workflow param.
     """
     # Fixture locations do not depend on the temp dir, so resolve them first.
     training_path = os.path.join(
         curr_dir, "test_data/discrete_action/cartpole_training.json.bz2"
     )
     eval_path = os.path.join(
         curr_dir, "test_data/discrete_action/cartpole_eval.json.bz2"
     )
     norm_path = os.path.join(
         curr_dir, "test_data/discrete_action/cartpole_norm.json"
     )
     with tempfile.TemporaryDirectory() as output_dir:
         workflow_params = {
             "training_data_path": training_path,
             "eval_data_path": eval_path,
             "state_norm_data_path": norm_path,
             "model_output_path": output_dir,
             "use_gpu": use_gpu,
             "use_all_avail_gpus": use_all_avail_gpus,
             "actions": ["0", "1"],
             "epochs": 1,
             "rl": {},
             "rainbow": {},
             "training": {"minibatch_size": 128},
         }
         trained_predictor = dqn_workflow.main(workflow_params)
         probe_state = {"0": 1.0, "1": 1.0, "2": 1.0, "3": 1.0}
         q_values = trained_predictor.predict([probe_state])
     first_prediction = q_values[0]
     assert len(first_prediction.keys()) == 2