def _make_experiment(exp_id=1, path="./Results/Tmp/test_PST"):
    """
    Each file specifying an experimental setup should contain a
    make_experiment function which returns an instance of the Experiment
    class with everything set up.

    @param exp_id: number used to seed the random number generators
    @param path: output directory where logs and results are stored
    """
    # Domain: persistent search-and-track with a fixed fleet of 3 UAVs
    NUM_UAV = 3
    domain = PST(NUM_UAV=NUM_UAV)

    # Representation
    # discretization only needed for continuous state spaces, discarded otherwise
    representation = IncrementalTabular(domain)

    # Policy
    policy = eGreedy(representation, epsilon=0.1)

    # Agent
    agent = SARSA(
        representation=representation,
        policy=policy,
        discount_factor=domain.discount_factor,
        initial_learn_rate=0.1,
    )

    # Experiment settings kept tiny so the test runs quickly.
    checks_per_policy = 2
    max_steps = 30
    num_policy_checks = 2

    # NOTE: Experiment(**locals()) relies on the local variable names above
    # matching Experiment's constructor parameters (project convention).
    experiment = Experiment(**locals())
    return experiment
def test_sarsa_valfun_chain():
    """
    Check if SARSA computes the value function of a simple Markov chain
    correctly. This only tests value function estimation, only one action
    possible
    """
    representation = MockRepresentation()
    policy = eGreedy(representation)
    agent = SARSA(policy, representation, 0.9, lambda_=0.0)

    # Sweep the 4-state chain repeatedly; state index 3 is skipped because it
    # is the terminal position of the previous transition.
    for step in range(1000):
        state = step % 4
        if state == 3:
            continue
        next_state = (step + 1) % 4
        terminal = (step + 2) % 4 == 0
        agent.learn(
            np.array([state]),
            [0],
            0,
            1.0,
            np.array([next_state]),
            [0],
            0,
            terminal,
        )

    expected_values = np.array([2.71, 1.9, 1, 0])
    np.testing.assert_allclose(representation.weight_vec, expected_values)
def _make_experiment(domain, exp_id=1, path="./Results/Tmp/test_InfTrackCartPole"):
    """
    Build an Experiment (SARSA agent with a tabular representation and an
    epsilon-greedy policy) on the given domain, with everything set up.

    @param domain: the domain instance the experiment runs on
    @param exp_id: number used to seed the random number generators
    @param path: output directory where logs and results are stored
    """
    # Representation
    # discretization only needed for continuous state spaces, discarded otherwise
    representation = Tabular(domain)

    # Policy
    policy = eGreedy(representation, epsilon=0.2)

    # Agent
    agent = SARSA(
        representation=representation,
        policy=policy,
        discount_factor=domain.discount_factor,
        initial_learn_rate=0.1,
    )

    # Experiment settings kept tiny so the test runs quickly.
    checks_per_policy = 3
    max_steps = 50
    num_policy_checks = 3

    # NOTE: Experiment(**locals()) relies on the local variable names above
    # matching Experiment's constructor parameters (project convention).
    experiment = Experiment(**locals())
    return experiment
def test_deepcopy():
    """A deep-copied SARSA agent preserves its eligibility-trace parameter."""
    representation = MockRepresentation()
    policy = eGreedy(representation)
    original = SARSA(policy, representation, 0.9, lambda_=0.0)
    clone = copy.deepcopy(original)
    assert original.lambda_ == clone.lambda_
def tabular_sarsa(domain, discretization=20, lambda_=0.3):
    """
    Construct a SARSA agent over a tabular representation of the domain.

    @param domain: the domain to build the agent for
    @param discretization: number of bins per continuous state dimension
    @param lambda_: eligibility-trace decay parameter
    """
    representation = Tabular(domain, discretization=discretization)
    greedy_policy = eGreedy(representation, epsilon=0.1)
    agent = SARSA(greedy_policy, representation, domain.discount_factor, lambda_=lambda_)
    return agent