Example #1
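Set up an n-step TD prediction experiment on a generic random-walk MRP (9 states on each side of the center state 'C', with +1/-1 terminal rewards). An equiprobable policy drives episode generation, and the dictionary true_valueD collects the analytic state values used later for RMS-error comparisons.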
import matplotlib.pyplot as plt

from introrl.utils.running_ave import RunningAve
from introrl.mc_funcs.mc_ev_prediction import mc_every_visit_prediction
from introrl.policy import Policy
from introrl.agent_supt.state_value_coll import StateValueColl
from introrl.agent_supt.nstep_td_eval_walker import NStepTDWalker
from introrl.mdp_data.random_walk_generic_mrp import get_random_walk
from introrl.agent_supt.episode_maker import make_episode

GAMMA=1.0

AVE_OVER = 100

rw_mrp = get_random_walk(Nside_states=9, win_reward=1.0, lose_reward=-1.0, step_reward=0.0)
policy = Policy( environment=rw_mrp )

policy.intialize_policy_to_equiprobable() # the Policy constructor already initializes equiprobable; made explicit here

episode_obj = make_episode( 'C', policy, rw_mrp )

fig, ax = plt.subplots()

# ---------------- set up true value data for RMS calc --------------------
true_valueD = {'C':0.0} # { 'Win':0.0, 'Lose':0.0}

#print('rw_mrp.get_num_states() = ',rw_mrp.get_num_states())
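# True values of the symmetric random walk are evenly spaced between the -1
# and +1 terminal rewards, passing through 0.0 at the center state 'C';
# delta is the value spacing between adjacent states.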
delta = 2.0 / (rw_mrp.get_num_states()-1)
Nsides = int( rw_mrp.get_num_states() / 2) - 1
d = 0.0
for i in range(1, Nsides+1 ):
    d += delta
    true_valueD[ 'L-%i'%i] = float('%g'%-d) # trim tiny floating-point residue from the value
Example #2
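A unittest suite for the Policy class on a small gridworld: setUp builds an equiprobable policy, and the tests cover setting the policy from the environment's default description dict, learning all states and actions, random initialization, iterating (action, probability) pairs and policy states, and single-action lookup. The unittest, Policy, and get_gridworld imports are reconstructed at the top; the gridworld module path is an assumption.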
import unittest

from introrl.policy import Policy
# NOTE: the gridworld module path below is an assumption; point it at the
# 3x4 gridworld this test suite targets.
from introrl.mdp_data.simple_grid_world import get_gridworld


class MyTest(unittest.TestCase):
    def setUp(self):
        unittest.TestCase.setUp(self)
        self.gridworld = get_gridworld()
        self.P = Policy(environment=self.gridworld)
        self.P.intialize_policy_to_equiprobable(env=self.gridworld)

    def tearDown(self):
        unittest.TestCase.tearDown(self)
        del self.P

    def test_should_always_pass_cleanly(self):
        """Should always pass cleanly."""
        pass

    def test_myclass_existence(self):
        """Check that myclass exists"""

        # See if the self.P object exists
        self.assertIsInstance(self.P, Policy, msg=None)

    def test_set_policy_from_default_pi(self):
        """test set policy from default pi"""

        policyD = self.gridworld.get_default_policy_desc_dict()
        self.P.set_policy_from_piD(policyD)

        self.assertEqual(self.P.get_action_prob((2, 2), 'U'), 1.0)
        self.assertEqual(self.P.get_action_prob((2, 2), 'R'), 0.0)
        self.assertEqual(self.P.get_action_prob((2, 2), 'D'), None)

    #def test_set_policy_from_list_of_actions(self):
    #    """test set policy from list of actions"""
    #    piD = {(0, 0):('R','D') }
    #    self.P.set_policy_from_piD( piD )

    #    self.assertEqual(self.P.get_action_prob( (0,0), 'U'), None)
    #    self.assertEqual(self.P.get_action_prob( (0,0), 'R'), 0.5)
    #    self.assertEqual(self.P.get_action_prob( (0,0), 'D'), 0.5)

    #def test_set_policy_from_list_of_action_probs(self):
    #    """test set policy from list of action probs"""
    #    piD = {(0, 0):[('R',0.6), ('D',0.4)] }
    #    self.P.set_policy_from_piD( piD )

    #    self.assertEqual(self.P.get_action_prob( (0,0), 'U'), None)
    #    self.assertEqual(self.P.get_action_prob( (0,0), 'R'), 0.6)
    #    self.assertEqual(self.P.get_action_prob( (0,0), 'D'), 0.4)

    #    # make (action, prob) entry too long.
    #    with self.assertRaises(ValueError):
    #        piD = {(0, 0):[('R',0.6,0.4), ('D',0.4,0.6)] }
    #        self.P.set_policy_from_piD( piD )

    def test_learn_all_s_and_a(self):
        """test learn all s and a"""

        self.P.learn_all_states_and_actions_from_env(self.gridworld)

    def test_initialize_to_random(self):
        """test initialize to random"""

        self.P.intialize_policy_to_random(env=self.gridworld)
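        # intialize_policy_to_random assigns all probability to one randomly
        # chosen action per state, so state (0, 2) should show probabilities
        # [0.0, 0.0, 1.0] once sorted.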
        apL = self.P.get_list_of_all_action_desc_prob((0, 2),
                                                      incl_zero_prob=True)
        pL = [p for (adesc, p) in apL]
        self.assertEqual(sorted(pL), [0.0, 0.0, 1.0])

    def test_iterate_adesc_p(self):
        """test iterate adesc p"""

        apL = []
        for (a_desc, p) in self.P.iter_policy_ap_for_state(
            (0, 0), incl_zero_prob=False):
            apL.append((a_desc, p))

        self.assertIn(('R', 0.5), apL)
        self.assertIn(('D', 0.5), apL)
        self.assertNotIn(('U', 0.5), apL)

    def test_iterate_all_states(self):
        """test iterate all states"""

        sL = []
        for s_hash in self.P.iter_all_policy_states():
            sL.append(s_hash)
        sL.sort()
        self.assertEqual(len(sL), 9)
        self.assertEqual(sL[0], (0, 0))
        self.assertEqual(sL[-1], (2, 3))

    def test_get_single_action(self):
        """test get single action"""
        a_desc = self.P.get_single_action((0, 0))
        self.assertIn(a_desc, ('R', 'D'))

        a_desc = self.P.get_single_action((99, 99))
        self.assertEqual(a_desc, None)
Example #3
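Estimate the state-visit frequency distribution of the 1000-state random walk: run 100,000 episodes from the center state (500) under an equiprobable policy, count every state visit, and print the normalized frequencies for a separate plotting script.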
from introrl.black_box_sims.random_walk_1000 import RandomWalk_1000Simulation
from introrl.agent_supt.episode_maker import make_episode
from introrl.policy import Policy

NUM_EPISODES = 100000
countD = {} # index=state, value=count 

RW = RandomWalk_1000Simulation()
policy = Policy(environment=RW)
policy.intialize_policy_to_equiprobable( env=RW )


for Nepi in range(NUM_EPISODES):
    episode = make_episode(500, policy, RW, max_steps=10000)
    
    for dr in episode.get_rev_discounted_returns( gamma=1.0 ):
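        # dr is a (state, action, reward, next_state, return) tuple,
        # yielded in reverse episode order.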
        (s_hash, a_desc, reward, sn_hash, G) = dr
        
        countD[ s_hash ] = countD.get( s_hash, 0 ) + 1

SUM_VISITS = sum( list(countD.values()) )
freqL = []
for i in range(1,1001):
    freqL.append( countD.get(i,0) / float(SUM_VISITS) )

# copy and paste list into plot script
print('freqL =', repr(freqL))
Example #4
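Run dynamic-programming policy evaluation of the equiprobable policy on the Sutton Example 4.1 gridworld, starting from all-zero state values, then print a summary of the policy.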
from introrl.dp_funcs.dp_policy_eval import dp_policy_evaluation
from introrl.policy import Policy
from introrl.state_values import StateValues
from introrl.mdp_data.sutton_ex4_1_grid import get_gridworld

gridworld = get_gridworld()

pi = Policy(environment=gridworld)
pi.intialize_policy_to_equiprobable(env=gridworld)

sv = StateValues(gridworld)
sv.init_Vs_to_zero()

dp_policy_evaluation(pi,
                     sv,
                     max_iter=1000,
                     err_delta=0.001,
                     gamma=1.,
                     fmt_V='%.1f')

#sv.summ_print( fmt_V='%.3f', show_states=False )
pi.summ_print(environment=gridworld, verbosity=0, show_env_states=False)

#print( gridworld.get_info() )