def _test_parametric_dqn_workflow(self, use_gpu=False, use_all_avail_gpus=False):
    """Smoke-test the Parametric DQN workflow end to end.

    Trains for a single epoch on the parametric-action cartpole fixtures,
    writing the model into a temporary directory, then queries the returned
    predictor once. The only check is that the run completes and that the
    first prediction holds exactly one entry; algorithm correctness is not
    tested here.

    Args:
        use_gpu: Forwarded to the workflow as the ``"use_gpu"`` setting.
        use_all_avail_gpus: Forwarded to the workflow as the
            ``"use_all_avail_gpus"`` setting.
    """

    def fixture(relative_path):
        # Fixture paths are resolved against the module-level ``curr_dir``.
        return os.path.join(curr_dir, relative_path)

    with tempfile.TemporaryDirectory() as tmpdirname:
        workflow_config = {
            "training_data_path": fixture(
                "test_data/parametric_action/cartpole_training.json.bz2"
            ),
            "eval_data_path": fixture(
                "test_data/parametric_action/cartpole_eval.json.bz2"
            ),
            "state_norm_data_path": fixture(
                "test_data/parametric_action/state_features_norm.json"
            ),
            "action_norm_data_path": fixture(
                "test_data/parametric_action/action_norm.json"
            ),
            "model_output_path": tmpdirname,
            "use_gpu": use_gpu,
            "use_all_avail_gpus": use_all_avail_gpus,
            "epochs": 1,
            "rl": {},
            "rainbow": {},
            "training": {"minibatch_size": 128},
        }
        trained_predictor = parametric_dqn_workflow.train_network(workflow_config)

        # A single example: four float state features, no int state
        # features, and one action described by two features.
        float_states = [{"0": 1.0, "1": 1.0, "2": 1.0, "3": 1.0}]
        int_states = [{}]
        actions = [{"4": 0.0, "5": 1.0}]

        q_values = trained_predictor.predict(float_states, int_states, actions)
        first_prediction = q_values[0]
        assert len(first_prediction.keys()) == 1
def test_parametric_dqn_workflow(self):
    """Smoke-test the Parametric DQN workflow on CPU.

    Trains for a single epoch on the parametric-action cartpole training
    fixture with no model output path, then queries the returned predictor
    once. The only check is that the run completes and that the first
    prediction holds exactly one entry; algorithm correctness is not tested
    here.
    """

    def fixture(relative_path):
        # Fixture paths are resolved against the module-level ``curr_dir``.
        return os.path.join(curr_dir, relative_path)

    workflow_config = {
        "training_data_path": fixture(
            "test_data/parametric_action/cartpole_training_data.json"
        ),
        "state_norm_data_path": fixture(
            "test_data/parametric_action/state_features_norm.json"
        ),
        "action_norm_data_path": fixture(
            "test_data/parametric_action/action_norm.json"
        ),
        "model_output_path": None,
        "use_gpu": False,
        "epochs": 1,
        "rl": {},
        "rainbow": {},
        "training": {"minibatch_size": 16},
        "in_training_cpe": None,
    }
    trained_predictor = parametric_dqn_workflow.train_network(workflow_config)

    # A single example: four float state features, no int state features,
    # and one action described by two features.
    float_states = [{"0": 1.0, "1": 1.0, "2": 1.0, "3": 1.0}]
    int_states = [{}]
    actions = [{"4": 0.0, "5": 1.0}]

    q_values = trained_predictor.predict(float_states, int_states, actions)
    first_prediction = q_values[0]
    assert len(first_prediction.keys()) == 1