def test_Packets(self):
    """Check the policy's packet-size requirement and its packet validation.

    Verifies that PacketSizeReq() reports the (_req_S, _req_A, _req_R)
    triple, and that IsValidPacket only accepts a packet once two
    experiences have been pushed into it.
    """
    # assertEqual (rather than assertTrue(a == b)) so a failure shows both tuples
    required = (self.policy._req_S, self.policy._req_A, self.policy._req_R)
    self.assertEqual(required, self.policy.PacketSizeReq())

    pkt = ExpPacket([], [], [])
    self.assertFalse(self.policy.IsValidPacket(pkt))  # Depth of 0, not good enough

    pkt.Push((0, 0), 0, -1)
    self.assertFalse(self.policy.IsValidPacket(pkt))  # Depth of 1, not good enough

    pkt.Push((1, 1), 2, -1)
    self.assertTrue(self.policy.IsValidPacket(pkt))  # Has depth of 2, should be good now

    # Make sure it returns False instead of throwing TypeError
    self.assertFalse(self.policy.IsValidPacket([]))
def test_GetTargetEstimate(self):
    """Check GetTargetEstimate on a two-step packet and on out-of-range states.

    The expected target is the first reward (-3) plus 1 * the policy's value
    for the second state/action pair ((1,2), 1).
    NOTE(review): the literal factor 1 presumably stands for the discount
    factor of the policy under test -- confirm against the fixture.
    """
    s_list = [(0, 0), (1, 2)]
    a_list = [0, 1]
    r_list = [-3, -5]
    self.packet = ExpPacket(s_list, a_list, r_list)
    G = self.policy.GetTargetEstimate(self.packet)
    expected = -3 + 1 * self.policy.GetStateVal((1, 2), 1)
    # assertEqual reports both values on failure, unlike assertTrue(G == ...)
    self.assertEqual(G, expected)

    # These states are expected to make GetTargetEstimate raise IndexError
    # (presumably they fall outside the policy's state space -- verify)
    s_list = [(0, 34), (1, 343)]
    a_list = [0, 1]
    r_list = [-3, -5]
    self.packet = ExpPacket(s_list, a_list, r_list)
    with self.assertRaises(IndexError):
        self.policy.GetTargetEstimate(self.packet)
def test_ImprovePolicy(self):
    """Check that ImprovePolicy rejects unusable packets and applies the SARSA update."""
    packet = ExpPacket([], [], [])

    # Snapshot the relevant values and hyper-parameters before any update
    value_before = self.policy.GetStateVal((0, 0), 0)
    next_value = self.policy.GetStateVal((1, 1), 2)
    alpha = self.policy.alpha
    gamma = self.policy.gamma

    # Empty packet, not enough depth to improve policy
    self.assertFalse(self.policy.ImprovePolicy(packet))
    # A non-packet argument must yield False instead of throwing TypeError
    self.assertFalse(self.policy.ImprovePolicy([]))

    packet.Push((0, 0), 0, -1)
    # Depth of 1, not good enough
    self.assertFalse(self.policy.ImprovePolicy(packet))

    packet.Push((1, 1), 2, -1)
    # Depth of 2, the policy should be improved now
    self.assertTrue(self.policy.ImprovePolicy(packet))

    # Make sure (0,0), 0 got updated
    value_after = self.policy.GetStateVal((0, 0), 0)
    self.assertNotEqual(value_before, value_after)

    # Make sure it follows the SARSA update
    sarsa_target = -1 + gamma * next_value
    expected = (1 - alpha) * value_before + alpha * sarsa_target
    self.assertAlmostEqual(value_after, expected)
def test_InitAndLen(self):
    """Check ExpPacket construction and the LenS/LenA/LenR accessors.

    The packet built here is the one that gets asserted on; previously it
    was discarded and the assertions ran against the self.pckt_f fixture,
    so this test never examined the object it constructed.
    """
    # Test that constructor with args works and Len is correct
    self.sl = [2, 3, 4]
    self.al = [0, 0, 1]
    self.rl = [-1, -1]
    packet = ExpPacket(self.sl, self.al, self.rl)
    self.assertEqual(packet.LenS(), 3)
    self.assertEqual(packet.LenA(), 3)
    self.assertEqual(packet.LenR(), 2)

    # Test that constructor with no args works and Len is correct
    self.assertEqual(self.pckt_e.LenS(), 0)
    self.assertEqual(self.pckt_e.LenA(), 0)
    self.assertEqual(self.pckt_e.LenR(), 0)

    # Test that constructor accepts diff types of lists (only checks that it does not raise)
    ExpPacket([2, 3, 4], (0, 0, 1), np.array((-1, -1)))
def test_ImprovePolicy(self):
    """Check that the agent's chosen action for state (0,0) tracks policy improvements.

    Uses assertEqual / assertNotEqual so that a failure reports the action
    actually returned by GetAction instead of a bare "False is not true".
    """
    # Make sure that policy can't be improved with empty packet
    self.assertFalse(self.agent.ImprovePolicy(ExpPacket()))

    # Improve the policy and check that the corresponding action is now selected
    pkt = ExpPacket([(0, 0), (1, 1)], [0, 0], [100])
    self.agent.ImprovePolicy(pkt)
    self.assertEqual(self.agent.GetAction((0, 0)), 0)

    # Improve again, with diff action, and check that it is now the best
    pkt = ExpPacket([(0, 0), (1, 0)], [2, 2], [1000])
    self.agent.ImprovePolicy(pkt)
    self.assertEqual(self.agent.GetAction((0, 0)), 2)

    # Improve again, with negative reward for same action, so it's no longer best
    pkt = ExpPacket([(0, 0), (1, 0)], [2, 2], [-10000])
    self.agent.ImprovePolicy(pkt)
    self.assertNotEqual(self.agent.GetAction((0, 0)), 2)
def setUp(self):
    """Create the state/action/reward lists and the three ExpPacket fixtures."""
    self.sl, self.al, self.rl = [2, 3, 4], [0, 0, 1], [-1, -1]

    # Filled packet, built from the lists above
    self.pckt_f = ExpPacket(self.sl, self.al, self.rl)
    # Empty packet, built with no arguments
    self.pckt_e = ExpPacket()

    # Max packet: three separate zero arrays, each of length ExpPacket.MAX_SIZE
    capacity = ExpPacket.MAX_SIZE
    zeros_s, zeros_a, zeros_r = (np.zeros(capacity) for _ in range(3))
    self.pckt_max = ExpPacket(zeros_s, zeros_a, zeros_r)
def GetLatestAsPacket(self, nS, nA, nR):
    """
    Packet-returning variant of GetLatest

    Params:
        - nS, nA, nR: passed through unchanged to GetLatest

    Returns:
        [ExpPacket] built from the three sub-lists GetLatest yields,
        [None] if GetLatest raises IndexError
    """
    try:
        latest = self.GetLatest(nS, nA, nR)
    except IndexError:
        # GetLatest could not serve the request, so no packet can be built
        return None
    else:
        states, actions, rewards = latest
        return ExpPacket(states, actions, rewards)
def __init__(self, world, agents):
    """
    Sets up an RLGame object, used to run episodes and train agents

    Params:
        - world: [World] the world that this game operates in (must be subclass of type World)
        - agents: [list] all the agents (initialized) participating in this game (must be list of type Agent)
    """
    self._world = world

    # Agents are looked up by the ID each one reports through GetID()
    self._agents = {agent.GetID(): agent for agent in agents}

    # Every agent ID starts out with its own empty ExpPacket of history
    self._history = {agent_id: ExpPacket() for agent_id in self._agents}

    self._episodes = []
class TestAgent(unittest.TestCase):
    """Unit tests exercising ExpPacket.

    Covers construction, Push, Get / GetLatest, IsReqDepth,
    GetLatestAsPacket and printing.
    """

    def setUp(self):
        """Create the filled, empty and max-size packets shared by the tests."""
        self.sl = [2, 3, 4]
        self.al = [0, 0, 1]
        self.rl = [-1, -1]
        self.pckt_f = ExpPacket(self.sl, self.al, self.rl)  # filled packet
        self.pckt_e = ExpPacket()  # Empty packet
        self.pckt_max = ExpPacket(
            np.zeros(ExpPacket.MAX_SIZE),
            np.zeros(ExpPacket.MAX_SIZE),
            np.zeros(ExpPacket.MAX_SIZE),
        )  # Max packet

    def test_InitAndLen(self):
        """Check construction and the LenS/LenA/LenR accessors.

        The assertions run against the packet constructed here; previously
        that packet was discarded and only the setUp fixture was checked.
        """
        # Test that constructor with args works and Len is correct
        packet = ExpPacket(self.sl, self.al, self.rl)
        self.assertEqual(packet.LenS(), 3)
        self.assertEqual(packet.LenA(), 3)
        self.assertEqual(packet.LenR(), 2)

        # Test that constructor with no args works and Len is correct
        self.assertEqual(self.pckt_e.LenS(), 0)
        self.assertEqual(self.pckt_e.LenA(), 0)
        self.assertEqual(self.pckt_e.LenR(), 0)

        # Test that constructor accepts diff types of lists (only checks that it does not raise)
        ExpPacket([2, 3, 4], (0, 0, 1), np.array((-1, -1)))

    def test_Push(self):
        """Check that Push accepts varied value types and refuses a max packet."""
        # Test that we can push diff types of ints, floats and lists
        self.assertTrue(self.pckt_f.Push(2, 6, 1))
        self.assertTrue(self.pckt_f.Push(1.1, 0.35, -2.32))
        self.assertTrue(self.pckt_f.Push((1, 3), [0, 5.3], np.array([-2, -2.21, -1])))

        # Make sure Len is still correct (3 pushes on top of the 3/3/2 fixture)
        self.assertEqual(self.pckt_f.LenS(), 6)
        self.assertEqual(self.pckt_f.LenA(), 6)
        self.assertEqual(self.pckt_f.LenR(), 5)

        # Check that we get false when trying to push into a max packet
        self.assertFalse(self.pckt_max.Push(2, 6, 1))

    def test_Get(self):
        """Check Get and GetLatest, including the IndexError on over-long requests."""
        # Make sure we can get the list in packet correctly
        sl, al, rl = self.pckt_f.Get()
        self.assertTrue(np.all(sl == self.sl))
        self.assertTrue(np.all(al == self.al))
        self.assertTrue(np.all(rl == self.rl))

        # Push some new exp into packet and mirror it in the reference lists
        self.pckt_f.Push(2, 6, 1)
        self.sl.append(2)
        self.al.append(6)
        self.rl.append(1)

        # Make sure we can get the list in packet correctly after pushing
        sl, al, rl = self.pckt_f.Get()
        self.assertTrue(np.all(sl == self.sl))
        self.assertTrue(np.all(al == self.al))
        self.assertTrue(np.all(rl == self.rl))

        # Make sure we can grab the latest packets properly
        sl, al, rl = self.pckt_f.GetLatest(2, 2, 2)
        self.assertTrue(np.all(sl == self.sl[-2:]))
        self.assertTrue(np.all(al == self.al[-2:]))
        self.assertTrue(np.all(rl == self.rl[-2:]))

        # Grab the curr length of the lists
        slen = self.pckt_f.LenS()
        alen = self.pckt_f.LenA()
        rlen = self.pckt_f.LenR()

        # Make sure GetLatest works with Len_() funcs (must not raise)
        self.pckt_f.GetLatest(slen, alen, rlen)

        # Make sure we get an IndexError if we try to access Latest > Num_Avail
        with self.assertRaises(IndexError):
            self.pckt_f.GetLatest(slen + 1, alen, rlen)
        with self.assertRaises(IndexError):
            self.pckt_f.GetLatest(slen, alen + 1, rlen)
        with self.assertRaises(IndexError):
            self.pckt_f.GetLatest(slen, alen, rlen + 1)

    def test_IsReqDepth(self):
        """Check IsReqDepth against depths within and beyond the fixture's lengths."""
        # Make sure IsReqDepth reports true when indices <= max for each list
        self.assertTrue(self.pckt_f.IsReqDepth(1, 1, 1))
        self.assertTrue(self.pckt_f.IsReqDepth(2, 2, 2))
        self.assertTrue(self.pckt_f.IsReqDepth(3, 3, 2))

        # Make sure IsReqDepth reports false if ANY index is > max for that list
        self.assertFalse(self.pckt_f.IsReqDepth(1, 1, 4))
        self.assertFalse(self.pckt_f.IsReqDepth(1, 6, 1))
        self.assertFalse(self.pckt_f.IsReqDepth(8, 1, 1))

    def test_GetLatestAsPacket(self):
        """Check that GetLatestAsPacket returns a packet matching GetLatest's slices."""
        # Make sure we get back an ExpPacket
        self.assertIsInstance(
            self.pckt_f.GetLatestAsPacket(1, 1, 1), self.pckt_f.__class__
        )

        # Get the sub packet of slice of latest exp
        sub_packet = self.pckt_f.GetLatestAsPacket(1, 2, 1)
        self.assertEqual(sub_packet.LenS(), 1)
        self.assertEqual(sub_packet.LenA(), 2)
        self.assertEqual(sub_packet.LenR(), 1)

        # Get the lists from the subpack and compare to sub_lists from original packet
        sl, al, rl = sub_packet.Get()
        sl_orig, al_orig, rl_orig = self.pckt_f.GetLatest(1, 2, 1)
        self.assertEqual(sl, sl_orig)
        self.assertEqual(al, al_orig)
        self.assertEqual(rl, rl_orig)

    def test_Print(self):
        """Check that printing a packet does not raise."""
        try:
            print(self.pckt_f)
        except Exception as e:
            # Turn any error from printing into a test failure with context
            self.fail(f"Encountered error {e}")
def test_NotImplementedIsValidPacket(self):
    """IsValidPacket on this policy is expected to raise NotImplementedError."""
    with self.assertRaises(NotImplementedError):
        empty_packet = ExpPacket([], [], [])
        self.policy.IsValidPacket(empty_packet)
def test_NotImplementedImprovePolicy(self):
    """ImprovePolicy on this policy is expected to raise NotImplementedError."""
    with self.assertRaises(NotImplementedError):
        empty_packet = ExpPacket([], [], [])
        self.policy.ImprovePolicy(empty_packet)
def test_NotImplementedGetTargetEstimate(self):
    """GetTargetEstimate is expected to raise NotImplementedError for empty and filled packets."""
    # Empty packet
    with self.assertRaises(NotImplementedError):
        empty_packet = ExpPacket([], [], [])
        self.policy.GetTargetEstimate(empty_packet)

    # Packet holding three states, actions and rewards
    with self.assertRaises(NotImplementedError):
        filled_packet = ExpPacket(
            [(0, 0), (3, 1), (8, 1)],
            [0, 2, 1],
            [-1, -1, 0],
        )
        self.policy.GetTargetEstimate(filled_packet)