def test_gridworld_sequential_adapter(self):
    """
    Create a gridworld environment, a logging policy, and a target policy.

    Evaluates the target policy using the direct OPE sequential doubly
    robust estimator, then transforms the log into an evaluation data
    page (EDP), which is passed to the OPE adapter.

    This test verifies the adaptation of EDPs into RLEstimatorInputs as
    employed by ReAgent, since ReAgent provides EDPs to Evaluators.
    Going from EDP -> RLEstimatorInput is more involved than
    RLEstimatorInput -> EDP because the EDP does not store the state at
    each timestep in each MDP, only the corresponding logged outputs and
    model outputs. Thus, the adapter must do some tricks to represent
    these timesteps as states so the OPE module can extract the correct
    outputs.

    Note that there is some randomness in the model outputs since the
    model is purposefully noisy. However, the same target policy is
    evaluated on the same logged walks through the gridworld, so the two
    results should be close in value (within 1).
    """
    random.seed(0)
    np.random.seed(0)
    torch.random.manual_seed(0)

    device = torch.device("cuda") if torch.cuda.is_available() else None
    gridworld = GridWorld.from_grid(
        [
            ["s", "0", "0", "0", "0"],
            ["0", "0", "0", "W", "0"],
            ["0", "0", "0", "0", "0"],
            ["0", "W", "0", "0", "0"],
            ["0", "0", "0", "0", "g"],
        ],
        max_horizon=TestOPEModuleAlgs.MAX_HORIZON,
    )

    action_space = ActionSpace(4)
    opt_policy = TabularPolicy(action_space)
    trainer = DPTrainer(gridworld, opt_policy)
    value_func = trainer.train(gamma=TestOPEModuleAlgs.GAMMA)

    behavior_policy = RandomRLPolicy(action_space)
    target_policy = EpsilonGreedyRLPolicy(
        opt_policy, TestOPEModuleAlgs.NOISE_EPSILON
    )
    model = NoiseGridWorldModel(
        gridworld,
        action_space,
        epsilon=TestOPEModuleAlgs.NOISE_EPSILON,
        max_horizon=TestOPEModuleAlgs.MAX_HORIZON,
    )
    value_func = DPValueFunction(target_policy, model, TestOPEModuleAlgs.GAMMA)
    ground_truth = DPValueFunction(
        target_policy, gridworld, TestOPEModuleAlgs.GAMMA
    )

    log = []
    log_generator = PolicyLogGenerator(gridworld, behavior_policy)
    num_episodes = TestOPEModuleAlgs.EPISODES
    for state in gridworld.states:
        for _ in range(num_episodes):
            log.append(log_generator.generate_log(state))

    estimator_input = RLEstimatorInput(
        gamma=TestOPEModuleAlgs.GAMMA,
        log=log,
        target_policy=target_policy,
        value_function=value_func,
        ground_truth=ground_truth,
    )

    edp = rlestimator_input_to_edp(estimator_input, len(model.action_space))

    dr_estimator = SeqDREstimator(
        weight_clamper=None, weighted=False, device=device
    )

    module_results = SequentialOPEstimatorAdapter.estimator_results_to_cpe_estimate(
        dr_estimator.evaluate(estimator_input)
    )
    adapter_results = SequentialOPEstimatorAdapter(
        dr_estimator, TestOPEModuleAlgs.GAMMA, device=device
    ).estimate(edp)

    # The failure messages were previously attached with a trailing
    # ", f-string", which builds a discarded tuple and never reaches the
    # assertion; they must be passed via the msg= parameter.
    self.assertAlmostEqual(
        adapter_results.raw,
        module_results.raw,
        delta=TestOPEModuleAlgs.CPE_PASS_BAR,
        msg=(
            f"OPE adapter results differed too much from underlying module "
            f"(Diff: {abs(adapter_results.raw - module_results.raw)} "
            f"> {TestOPEModuleAlgs.CPE_PASS_BAR})"
        ),
    )
    self.assertLess(
        adapter_results.raw,
        TestOPEModuleAlgs.CPE_MAX_VALUE,
        msg=(
            f"OPE adapter results are too large "
            f"({adapter_results.raw} > {TestOPEModuleAlgs.CPE_MAX_VALUE})"
        ),
    )
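# A minimal conceptual sketch of the "timesteps as states" trick the
# docstring above describes: an EDP stores only per-timestep logged and
# model outputs, so the adapter must manufacture synthetic states keyed
# by position in the log. Everything below (TimestepState,
# group_edp_rows) is a hypothetical illustration, NOT ReAgent's actual
# adapter implementation.
from dataclasses import dataclass
from typing import Dict, List, Sequence


@dataclass(frozen=True)
class TimestepState:
    # A stand-in state identified only by where it occurs in the log.
    mdp_id: int
    step: int


def group_edp_rows(mdp_ids: Sequence[int]) -> List[List[TimestepState]]:
    # Group the EDP's flattened rows back into per-episode sequences of
    # synthetic states, so a sequential estimator can walk each MDP.
    episodes: Dict[int, List[TimestepState]] = {}
    for mdp_id in mdp_ids:
        steps = episodes.setdefault(mdp_id, [])
        steps.append(TimestepState(mdp_id=mdp_id, step=len(steps)))
    return list(episodes.values())


# Example: three rows from episode 0 and two rows from episode 1 become
# two synthetic MDPs of lengths 3 and 2.
assert [len(m) for m in group_edp_rows([0, 0, 0, 1, 1])] == [3, 2]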
action_space = ActionSpace(4)
opt_policy = TabularPolicy(action_space)
trainer = DPTrainer(gridworld, opt_policy)
value_func = trainer.train(gamma=GAMMA)
logging.info(f"Opt Policy:\n{gridworld.dump_policy(opt_policy)}")
logging.info(f"Opt state values:\n{gridworld.dump_value_func(value_func)}")

behavior_policy = RandomRLPolicy(action_space)
target_policy = EpsilonGreedyRLPolicy(opt_policy, 0.3)
model = NoiseGridWorldModel(
    gridworld, action_space, epsilon=0.3, max_horizon=1000
)
value_func = DPValueFunction(target_policy, model, GAMMA)
ground_truth = DPValueFunction(target_policy, gridworld, GAMMA)
logging.info(
    f"Target Policy ground truth values:\n"
    f"{gridworld.dump_value_func(ground_truth)}"
)

log = {}
log_generator = PolicyLogGenerator(gridworld, behavior_policy)
num_episodes = 200
for state in gridworld.states:
    mdps = []
    for _ in range(num_episodes):
        mdps.append(log_generator.generate_log(state))
    log[state] = mdps
    logging.info(f"Generated {len(mdps)} logs for {state}")
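# The fragment above stops after generating the logged MDPs. A sketch of
# the evaluation step that would typically follow, mirroring
# test_gridworld_sequential_adapter above, is shown below. This is an
# assumed continuation, not verbatim source, and it assumes
# RLEstimatorInput accepts the per-state mapping form of `log` built here.
estimator_input = RLEstimatorInput(
    gamma=GAMMA,
    log=log,
    target_policy=target_policy,
    value_function=value_func,
    ground_truth=ground_truth,
)
results = SeqDREstimator(weight_clamper=None, weighted=False).evaluate(
    estimator_input
)
logging.info(f"Sequential DR estimates: {results}")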