def test_from_mdp_lst_biased(self):
    """Check that ``AgentEvaluator.from_mdp_lst`` samples layouts with the
    non-uniform frequencies in ``self.biased``.

    The environment is reset (regenerating the MDP) ``self.num_reset`` times.
    The empirical frequency of each layout name is then compared with the
    configured frequency to 2 decimal places. Every configured layout is
    checked, including ones that were never sampled; those count as
    frequency 0.
    """
    from collections import Counter

    layouts = self.layout_name_short_lst
    # The i-th frequency belongs to the i-th layout; a length mismatch would
    # otherwise surface as an opaque IndexError.
    self.assertEqual(
        len(self.biased),
        len(layouts),
        "sampling frequencies and layout list differ in length",
    )
    mdp_lst = [OvercookedGridworld.from_layout_name(name) for name in layouts]
    ae = AgentEvaluator.from_mdp_lst(
        mdp_lst=mdp_lst, env_params={"horizon": 400}, sampling_freq=self.biased
    )

    counts = Counter()
    for _ in range(self.num_reset):
        ae.env.reset(regen_mdp=True)
        counts[ae.env.mdp.layout_name] += 1

    # Ground truth: expected sampling frequency for each layout name.
    gt = dict(zip(layouts, self.biased))

    # Any sampled layout must come from the configured list.
    unexpected = set(counts) - set(gt)
    self.assertFalse(
        unexpected, "sampled layouts not in layout list: %s" % sorted(unexpected)
    )
    # Iterate over the ground truth rather than the observed counts so that a
    # layout that was never sampled (count 0) is still checked.
    for k, expected in gt.items():
        self.assertAlmostEqual(
            expected,
            counts[k] / self.num_reset,
            2,
            "more than 2 places off for " + k,
        )
def test_from_mdp_lst_uniform(self):
    """Check that ``AgentEvaluator.from_mdp_lst`` samples layouts uniformly
    when it is given equal sampling frequencies.

    Each configured layout gets frequency ``1 / len(layouts)``. For the
    original 5-layout list this is exactly 0.2, the previously hard-coded
    value. After ``self.num_reset`` MDP-regenerating resets, each layout's
    empirical frequency must match that value to 2 decimal places. Layouts
    that were never sampled are checked as frequency 0.
    """
    from collections import Counter

    layouts = self.layout_name_short_lst
    self.assertTrue(layouts, "layout list must not be empty")
    expected = 1.0 / len(layouts)
    # NOTE(review): if from_mdp_lst validates that frequencies sum to exactly
    # 1.0, some list lengths may hit float rounding; 5 layouts sum to 1.0.
    mdp_lst = [OvercookedGridworld.from_layout_name(name) for name in layouts]
    ae = AgentEvaluator.from_mdp_lst(
        mdp_lst=mdp_lst,
        env_params={"horizon": 400},
        sampling_freq=[expected] * len(layouts),
    )

    counts = Counter()
    for _ in range(self.num_reset):
        ae.env.reset(regen_mdp=True)
        counts[ae.env.mdp.layout_name] += 1

    # Any sampled layout must come from the configured list.
    unexpected = set(counts) - set(layouts)
    self.assertFalse(
        unexpected, "sampled layouts not in layout list: %s" % sorted(unexpected)
    )
    # Check every configured layout, not just the ones observed, so a layout
    # the sampler never picks is caught.
    for k in layouts:
        self.assertAlmostEqual(
            expected, counts[k] / self.num_reset, 2, "more than 2 places off for " + k
        )