def _test_dqn_workflow(self, use_gpu=False, use_all_avail_gpus=False):
    """Run DQN workflow to ensure no crashes, algorithm correctness not tested here.

    Trains for one epoch on the cartpole fixture data, writing model output
    into a temporary directory. It then loads the single exported
    ``predictor_*.c2`` file with ``DQNPredictor`` and checks that predicting
    on one sample state yields exactly two Q-values (one per entry in
    ``params["actions"]``).

    Args:
        use_gpu: Forwarded to the workflow as ``params["use_gpu"]``.
        use_all_avail_gpus: Forwarded as ``params["use_all_avail_gpus"]``.
    """
    # NOTE(review): SOURCE also contains a second `_test_dqn_workflow`
    # definition. If both live in the same class, the later one silently
    # shadows this one; confirm which version is intended.
    with tempfile.TemporaryDirectory() as tmpdirname:
        # An empty file inside the temp dir is handed to the workflow as a
        # ``file://`` init_method. Presumably this is used for multi-process
        # rendezvous (num_nodes=1 here); confirm against dqn_workflow.
        lockfile = os.path.join(tmpdirname, "multiprocess_lock")
        Path(lockfile).touch()
        params = {
            "training_data_path": os.path.join(
                curr_dir, "test_data/discrete_action/cartpole_training.json.bz2"
            ),
            "eval_data_path": os.path.join(
                curr_dir, "test_data/discrete_action/cartpole_eval.json.bz2"
            ),
            "state_norm_data_path": os.path.join(
                curr_dir, "test_data/discrete_action/cartpole_norm.json"
            ),
            "model_output_path": tmpdirname,
            "use_gpu": use_gpu,
            "use_all_avail_gpus": use_all_avail_gpus,
            "init_method": "file://" + lockfile,
            "num_nodes": 1,
            "node_index": 0,
            "actions": ["0", "1"],
            "epochs": 1,
            "rl": {},
            "rainbow": {"double_q_learning": False, "dueling_architecture": False},
            "training": {"minibatch_size": 128},
        }
        dqn_workflow.main(params)

        predictor_files = glob.glob(os.path.join(tmpdirname, "predictor_*.c2"))
        # Report the real count: zero files (export failed) is as likely a
        # failure mode as duplicates, so a fixed "two files" message misleads.
        assert len(predictor_files) == 1, (
            "Expected exactly one predictor file, found {}: {}".format(
                len(predictor_files), predictor_files
            )
        )
        predictor = DQNPredictor.load(predictor_files[0], "minidb")

        test_float_state_features = [{"0": 1.0, "1": 1.0, "2": 1.0, "3": 1.0}]
        q_values = predictor.predict(test_float_state_features)
        # One Q-value entry per configured action ("0" and "1").
        assert len(q_values[0].keys()) == 2
def _test_dqn_workflow(self, use_gpu=False, use_all_avail_gpus=False):
    """Smoke-test the DQN workflow end to end; learned values are not checked.

    Builds a parameter dict pointing at the cartpole fixture files, runs one
    training epoch with output written to a temporary directory, and uses the
    predictor returned by ``dqn_workflow.main`` to score a single state. The
    test passes when that state yields exactly two Q-values.

    Args:
        use_gpu: Passed through as ``params["use_gpu"]``.
        use_all_avail_gpus: Passed through as ``params["use_all_avail_gpus"]``.
    """

    def fixture(relative_path):
        # Resolve a test-data path relative to this test module's directory.
        return os.path.join(curr_dir, relative_path)

    with tempfile.TemporaryDirectory() as output_dir:
        params = dict(
            training_data_path=fixture(
                "test_data/discrete_action/cartpole_training.json.bz2"
            ),
            eval_data_path=fixture(
                "test_data/discrete_action/cartpole_eval.json.bz2"
            ),
            state_norm_data_path=fixture(
                "test_data/discrete_action/cartpole_norm.json"
            ),
            model_output_path=output_dir,
            use_gpu=use_gpu,
            use_all_avail_gpus=use_all_avail_gpus,
            actions=["0", "1"],
            epochs=1,
            rl={},
            rainbow={},
            training={"minibatch_size": 128},
        )
        trained_predictor = dqn_workflow.main(params)

        probe_states = [{"0": 1.0, "1": 1.0, "2": 1.0, "3": 1.0}]
        action_q_values = trained_predictor.predict(probe_states)
        # Expect one Q-value per configured action.
        num_actions_scored = len(action_q_values[0].keys())
        assert num_actions_scored == 2