示例#1
0
		def test_Packets(self):
			# Checks the packet size requirement reported by the policy and
			# that IsValidPacket only accepts packets that are deep enough.

			# PacketSizeReq must report the policy's stored (S, A, R) requirement.
			# assertEqual prints both tuples on failure, unlike assertTrue(a == b).
			expected_req = (self.policy._req_S, self.policy._req_A, self.policy._req_R)
			self.assertEqual( expected_req, self.policy.PacketSizeReq() )

			pkt = ExpPacket([],[],[])
			self.assertFalse(self.policy.IsValidPacket(pkt)) # Depth of 0, not good enough
			pkt.Push( (0,0), 0, -1 )
			self.assertFalse(self.policy.IsValidPacket(pkt)) # Depth of 1, not good enough
			pkt.Push( (1,1), 2, -1 )
			self.assertTrue(self.policy.IsValidPacket(pkt)) # Has depth of 2, should be good now

			# A non-packet argument must be rejected with False instead of raising TypeError
			self.assertFalse(self.policy.IsValidPacket([]))
示例#2
0
		def test_GetTargetEstimate(self):
			# Checks the target estimate for a two-step packet, and that a
			# packet with states (0,34) / (1,343) makes GetTargetEstimate raise
			# IndexError (presumably these are outside the policy's state range
			# - confirm against the policy fixture).

			s_list = [(0,0), (1,2)]
			a_list = [0, 1]
			r_list = [-3, -5]
			self.packet = ExpPacket(s_list, a_list, r_list)
			G = self.policy.GetTargetEstimate(self.packet)

			# Expected target: first reward plus the value of the last (state, action)
			# pair. assertEqual reports both numbers if they differ.
			expected_G = -3 + 1 * self.policy.GetStateVal((1,2), 1)
			self.assertEqual(G, expected_G)

			s_list = [(0,34), (1,343)]
			a_list = [0, 1]
			r_list = [-3, -5]
			self.packet = ExpPacket(s_list, a_list, r_list)
			with self.assertRaises(IndexError):
				self.policy.GetTargetEstimate(self.packet)
示例#3
0
		def test_ImprovePolicy(self):
			# Checks that ImprovePolicy refuses packets that are too shallow and
			# applies the SARSA update to (0,0), 0 once the packet has depth 2.
			pkt = ExpPacket([],[],[])

			# Snapshot the values and hyper-parameters needed to predict the update
			old_val = self.policy.GetStateVal((0,0), 0)
			val_1_1_2 = self.policy.GetStateVal((1,1), 2)
			alpha = self.policy.alpha
			gamma = self.policy.gamma

			self.assertFalse(self.policy.ImprovePolicy(pkt)) # Empty packet, not enough depth to improve policy
			self.assertFalse(self.policy.ImprovePolicy([])) # Make sure it returns False instead of throwing TypeError
			pkt.Push( (0,0), 0, -1 )
			self.assertFalse(self.policy.ImprovePolicy(pkt)) # Depth of 1, not good enough
			pkt.Push( (1,1), 2, -1 )
			self.assertTrue(self.policy.ImprovePolicy(pkt)) # Depth of 2, policy should be improved now

			new_val = self.policy.GetStateVal((0,0), 0)
			self.assertNotEqual(old_val, new_val) # Make sure (0,0), 0 got updated

			# Make sure the new value follows the SARSA update (reward of -1 from the first push)
			expected_val = (1-alpha)*old_val + alpha*(-1 + gamma*val_1_1_2)
			self.assertAlmostEqual(new_val, expected_val)
示例#4
0
		def test_InitAndLen(self):

			# Test that constructor with args works and Len is correct.
			# The packet built here is checked as well as the fixture packet,
			# so the constructor call in this test is actually verified.
			self.sl = [2,3,4]
			self.al = [0,0,1]
			self.rl = [-1,-1]
			fresh_pckt = ExpPacket( self.sl, self.al, self.rl )
			for pckt in (self.pckt_f, fresh_pckt):
				self.assertEqual(pckt.LenS(), 3)
				self.assertEqual(pckt.LenA(), 3)
				self.assertEqual(pckt.LenR(), 2)

			# Test that constructor with no args works and Len is correct
			self.assertEqual(self.pckt_e.LenS(), 0)
			self.assertEqual(self.pckt_e.LenA(), 0)
			self.assertEqual(self.pckt_e.LenR(), 0)

			# Test that constructor accepts diff types of lists (must not raise)
			ExpPacket( [2,3,4], (0,0,1), np.array((-1,-1)) )
示例#5
0
		def test_ImprovePolicy(self):
			# Checks that the agent's chosen action for state (0,0) follows the
			# rewards fed in through ImprovePolicy.

			# make sure that policy can't be improved with empty packet
			self.assertFalse(self.agent.ImprovePolicy( ExpPacket() ))

			# improve the policy and check that the corresponding action is now selected
			pkt = ExpPacket([(0,0),(1,1)],[0, 0],[100])
			self.agent.ImprovePolicy( pkt )
			self.assertEqual(self.agent.GetAction((0,0)), 0)

			# improve again, with diff action, and check that it is now the best
			pkt = ExpPacket([(0,0),(1,0)],[2, 2],[1000])
			self.agent.ImprovePolicy( pkt )
			self.assertEqual(self.agent.GetAction((0,0)), 2)

			# Improve again, with negative reward for same action, so it's no longer best
			pkt = ExpPacket([(0,0),(1,0)],[2, 2],[-10000])
			self.agent.ImprovePolicy(pkt)
			self.assertNotEqual(self.agent.GetAction((0,0)), 2)
示例#6
0
		def setUp(self):
			# Fixture: the raw lists plus a filled, an empty and a max-size packet
			self.sl, self.al, self.rl = [2,3,4], [0,0,1], [-1,-1]

			# Packet pre-filled with the lists above
			self.pckt_f = ExpPacket( self.sl, self.al, self.rl )

			# Packet built with no arguments
			self.pckt_e = ExpPacket()

			# Packet whose three lists each hold ExpPacket.MAX_SIZE zeros
			max_states = np.zeros(ExpPacket.MAX_SIZE)
			max_actions = np.zeros(ExpPacket.MAX_SIZE)
			max_rewards = np.zeros(ExpPacket.MAX_SIZE)
			self.pckt_max = ExpPacket( max_states, max_actions, max_rewards )
示例#7
0
	def GetLatestAsPacket(self, nS, nA, nR):
		"""
		Same as GetLatest, but wraps the returned sublists in a new ExpPacket

		Params:
		- nS, nA, nR: forwarded unchanged to GetLatest

		Returns:
		[ExpPacket] built from the result of GetLatest, or [None] if GetLatest raises IndexError
		"""

		try:
			latest = self.GetLatest(nS, nA, nR)
		except IndexError:
			# GetLatest could not supply the requested experience
			return None
		else:
			states, actions, rewards = latest
			return ExpPacket(states, actions, rewards)
示例#8
0
	def __init__(self, world, agents):
		"""
		Initializes an RLGame object, used to run episodes and train agents

		Params:
		- world: [World] the world that this game operates in (must be subclass of type World)
		- agents: [list] all the agents (initialized) participating in this game (must be list of type Agent)

		"""
		self._world = world

		# Index the agents by their ID; if two agents report the same ID,
		# the later one in the list replaces the earlier one
		self._agents = {agent.GetID(): agent for agent in agents}

		# One freshly constructed ExpPacket per agent, keyed by the same agent IDs
		self._history = {agent_id: ExpPacket() for agent_id in self._agents}

		# Starts out empty
		self._episodes = []
示例#9
0
	class TestAgent(unittest.TestCase):
		"""
		Unit tests for ExpPacket: construction, Push, Get / GetLatest,
		IsReqDepth, GetLatestAsPacket and printing.

		NOTE: despite the class name, every test here exercises ExpPacket.
		"""

		def setUp(self):
			# Raw lists, kept so tests can compare packet contents against them
			self.sl = [2,3,4]
			self.al = [0,0,1]
			self.rl = [-1,-1]
			self.pckt_f = ExpPacket( self.sl, self.al, self.rl ) # filled packet
			self.pckt_e = ExpPacket()	# Empty packet
			self.pckt_max = ExpPacket( np.zeros(ExpPacket.MAX_SIZE),
										np.zeros(ExpPacket.MAX_SIZE),
										np.zeros(ExpPacket.MAX_SIZE) ) # Max packet

		def test_InitAndLen(self):

			# Test that constructor with args works and Len is correct.
			# The packet built here is checked as well as the fixture packet,
			# so the constructor call in this test is actually verified.
			self.sl = [2,3,4]
			self.al = [0,0,1]
			self.rl = [-1,-1]
			fresh_pckt = ExpPacket( self.sl, self.al, self.rl )
			for pckt in (self.pckt_f, fresh_pckt):
				self.assertEqual(pckt.LenS(), 3)
				self.assertEqual(pckt.LenA(), 3)
				self.assertEqual(pckt.LenR(), 2)

			# Test that constructor with no args works and Len is correct
			self.assertEqual(self.pckt_e.LenS(), 0)
			self.assertEqual(self.pckt_e.LenA(), 0)
			self.assertEqual(self.pckt_e.LenR(), 0)

			# Test that constructor accepts diff types of lists (must not raise)
			ExpPacket( [2,3,4], (0,0,1), np.array((-1,-1)) )

		def test_Push(self):

			# Test that we can push diff types of ints, floats and lists
			self.assertTrue(self.pckt_f.Push(2,6,1))
			self.assertTrue(self.pckt_f.Push(1.1,0.35,-2.32))
			self.assertTrue(self.pckt_f.Push( (1,3) , [0, 5.3] , np.array([-2,-2.21,-1]) ))

			# Make sure Len is still correct (3 pushes on top of 3 / 3 / 2 entries)
			self.assertEqual(self.pckt_f.LenS(), 6)
			self.assertEqual(self.pckt_f.LenA(), 6)
			self.assertEqual(self.pckt_f.LenR(), 5)

			# Check that we get false when trying to push into a max packet
			self.assertFalse(self.pckt_max.Push(2,6,1))

		def test_Get(self):

			# Make sure we can get the list in packet correctly
			sl, al, rl = self.pckt_f.Get()
			self.assertTrue( np.all( sl == self.sl ) )
			self.assertTrue( np.all( al == self.al ) )
			self.assertTrue( np.all( rl == self.rl ) )

			# Push some new exp into packet, and mirror it in the reference lists
			self.pckt_f.Push(2,6,1)
			self.sl.append(2)
			self.al.append(6)
			self.rl.append(1)

			# Make sure we can get the list in packet correctly after pushing
			sl, al, rl = self.pckt_f.Get()
			self.assertTrue( np.all( sl == self.sl ) )
			self.assertTrue( np.all( al == self.al ) )
			self.assertTrue( np.all( rl == self.rl ) )

			# Make sure we can grab the latest packets properly
			sl, al, rl = self.pckt_f.GetLatest(2,2,2)
			self.assertTrue( np.all( sl == self.sl[-2:] ) )
			self.assertTrue( np.all( al == self.al[-2:] ) )
			self.assertTrue( np.all( rl == self.rl[-2:] ) )

			# Grab the curr length of the lists
			slen = self.pckt_f.LenS()
			alen = self.pckt_f.LenA()
			rlen = self.pckt_f.LenR()

			# Make sure GetLatest works with Len_() funcs (only checks that it does not raise)
			self.pckt_f.GetLatest(slen, alen, rlen)

			# Make sure we get an IndexError if we try to access Latest > Num_Avail
			with self.assertRaises(IndexError):
				self.pckt_f.GetLatest(slen+1, alen, rlen)

			with self.assertRaises(IndexError):
				self.pckt_f.GetLatest(slen, alen+1, rlen)

			with self.assertRaises(IndexError):
				self.pckt_f.GetLatest(slen, alen, rlen+1)

		def test_IsReqDepth(self):

			# Make sure IsReqDepth reports true when indices <= max for each list
			self.assertTrue( self.pckt_f.IsReqDepth(1,1,1) )
			self.assertTrue( self.pckt_f.IsReqDepth(2,2,2) )
			self.assertTrue( self.pckt_f.IsReqDepth(3,3,2) )

			# Make sure IsReqDepth reports false if ANY index is > max for that list
			self.assertFalse( self.pckt_f.IsReqDepth(1,1,4) )
			self.assertFalse( self.pckt_f.IsReqDepth(1,6,1) )
			self.assertFalse( self.pckt_f.IsReqDepth(8,1,1) )

		def test_GetLatestAsPacket(self):

			# Make sure we get back an object of the same class as the source packet
			self.assertIsInstance( self.pckt_f.GetLatestAsPacket(1,1,1), type(self.pckt_f) )

			# Get the sub packet of slice of latest exp
			sub_packet = self.pckt_f.GetLatestAsPacket(1,2,1)
			self.assertEqual( sub_packet.LenS(), 1 )
			self.assertEqual( sub_packet.LenA(), 2 )
			self.assertEqual( sub_packet.LenR(), 1 )

			# Get the lists from the subpack and compare to sub_lists from original packet
			sl, al, rl = sub_packet.Get()
			sl_orig, al_orig, rl_orig = self.pckt_f.GetLatest(1,2,1)
			self.assertEqual( sl, sl_orig )
			self.assertEqual( al, al_orig )
			self.assertEqual( rl, rl_orig )

		def test_Print(self):
			# Printing a packet must not raise; any exception is reported as a test failure
			try:
				print(self.pckt_f)
			except Exception as e:
				self.fail(f"Encountered error {e}")
示例#10
0
		def test_NotImplementedIsValidPacket(self):
			# IsValidPacket on this policy is expected to raise NotImplementedError
			with self.assertRaises(NotImplementedError):
				validate = self.policy.IsValidPacket
				validate( ExpPacket([], [], []) )
示例#11
0
		def test_NotImplementedImprovePolicy(self):
			# ImprovePolicy on this policy is expected to raise NotImplementedError
			with self.assertRaises(NotImplementedError):
				improve = self.policy.ImprovePolicy
				improve( ExpPacket([], [], []) )
示例#12
0
		def test_NotImplementedGetTargetEstimate(self):
			# GetTargetEstimate is expected to raise NotImplementedError for
			# both an empty packet and a packet holding three steps
			packet_args = (
				([], [], []),
				([(0,0),(3,1),(8,1)], [0,2,1], [-1,-1,0]),
			)
			for s_list, a_list, r_list in packet_args:
				with self.assertRaises(NotImplementedError):
					self.policy.GetTargetEstimate( ExpPacket(s_list, a_list, r_list) )