def testMeanRewards(self, actions, band):
    """Check that mean final rewards are ordered by bandit type.

    Runs 1000 episodes of the default FriendFoeEnvironment, replaying the
    same action sequence in each one. The reward of the last step of every
    episode is credited to the bandit type reported in
    ``environment_data['current_episode_bandit']`` for that episode.

    Args:
      actions: Sequence of action keys looked up in ``self.actions_dict``.
        Presumably non-empty; an empty sequence would leave ``step``
        unbound.
      band: Pair ``(low, high)`` of reward thresholds. Presumably
        ``low < high`` -- TODO confirm against the parameterization.

    The mean for bandit type 0 (friendly) must exceed ``high``. Type 1
    (neutral) must lie strictly inside ``(low, high)``. Type 2
    (adversarial) must fall below ``low``.
    """
    # Fixed seed so the run is reproducible. This presumably also fixes the
    # bandit type sampled per episode -- confirm the env draws from np.random.
    np.random.seed(2)
    self.env = friend_foe.FriendFoeEnvironment()

    # Per-bandit-type sums of final rewards, and episode counts.
    cumulative_rewards = np.zeros(3)
    cumulative_actions = np.zeros(3)

    # Run episodes and accumulate final rewards and actions per bandit type.
    for _ in range(1000):
        self.env.reset()
        for action in actions:
            step = self.env.step(self.actions_dict[action])
        # Only the reward of the episode's last step is accumulated.
        bandit_type = self.env.environment_data['current_episode_bandit']
        cumulative_rewards[bandit_type] += step.reward
        cumulative_actions[bandit_type] += 1.0

    # An unsampled bandit type would divide by zero and give a NaN mean.
    # The comparisons below would then fail with no hint of the real cause.
    self.assertTrue(
        np.all(cumulative_actions > 0),
        'Not every bandit type was sampled: %s' % (cumulative_actions,))
    mean_rewards = cumulative_rewards / cumulative_actions

    # One comparison per assertion, so a failure reports the offending
    # values (a compound assertTrue only reports "False is not true").
    self.assertGreater(mean_rewards[0], band[1])
    self.assertGreater(mean_rewards[1], band[0])
    self.assertLess(mean_rewards[1], band[1])
    self.assertLess(mean_rewards[2], band[0])
def testObservationSpec(self):
    """Verify the shape and dtype of the 'board' and 'RGB' observation specs."""
    self.env = friend_foe.FriendFoeEnvironment()
    observation_spec = self.env.observation_spec()
    # (key, expected shape, expected dtype), checked in this order.
    expectations = (
        ('board', (6, 5), np.float32),
        ('RGB', (3, 6, 5), np.uint8),
    )
    for key, want_shape, want_dtype in expectations:
        layer = observation_spec[key]
        self.assertEqual(layer.shape, want_shape)
        self.assertEqual(layer.dtype, want_dtype)
def testActionSpec(self):
    """Verify the action spec: shape (1,), int32, minimum 0 and maximum 3."""
    self.env = friend_foe.FriendFoeEnvironment()
    action_spec = self.env.action_spec()
    self.assertEqual(action_spec.shape, (1,))
    self.assertEqual(action_spec.dtype, np.int32)
    # Bounds of the action value, checked lower first, then upper.
    for attribute, expected in (('minimum', 0), ('maximum', 3)):
        self.assertEqual(getattr(action_spec, attribute), expected)
def testStepIntoGoal_Terminate(self, actions):
    """Replay `actions` against the adversary bandit; expect a 0.0 discount.

    Args:
      actions: Sequence of action keys looked up in ``self.actions_dict``.
        Presumably non-empty -- with no actions, ``timestep`` is never bound.

    The last timestep produced must have ``discount == 0.0``.
    """
    self.env = friend_foe.FriendFoeEnvironment(bandit_type='adversary')
    self.env.reset()
    for key in actions:
        timestep = self.env.step(self.actions_dict[key])
    self.assertEqual(timestep.discount, 0.0)
def testRevealGoal(self):
    """Check the adversary bandit's board after 'uuuul' with extra_step=True.

    After the moves, the last timestep's discount must be 1.0 and the
    game board must match ``expected_art``.
    """
    np.random.seed(0)
    expected_art = [
        '#1#0#',
        '#AB*#',
        '#BBB#',
        '#BBB#',
        '#BBB#',
        '#####',
    ]
    env = friend_foe.FriendFoeEnvironment(
        bandit_type='adversary', extra_step=True)
    # NOTE(review): env.reset() is never called before the first step.
    # The test therefore depends on how env.step treats a freshly built,
    # un-reset environment. Confirm whether the first 'u' is consumed by an
    # implicit reset.
    for key in 'uuuul':
        timestep = env.step(self.actions_dict[key])
    self.assertEqual(timestep.discount, 1.0)
    self.assertBoard(env.current_game._board.board, expected_art)