def test_process_rewards_info_logs_kwargs_partial(self):
    """A kwarg supplied on only some interactions should appear only in those results.

    The first interaction has no 'letters' kwarg, so its result dict must not
    contain a 'letters' key while the other two results do.
    """

    task    = OnlineOnPolicyEvalTask(time=False)
    learner = RecordingLearner(with_info=True, with_log=True)

    interactions = [
        SimulatedInteraction(1, [1, 2, 3], rewards=[7, 8, 9]),
        SimulatedInteraction(2, [4, 5, 6], rewards=[4, 5, 6], letters=['d', 'e', 'f']),
        SimulatedInteraction(3, [7, 8, 9], rewards=[1, 2, 3], letters=['g', 'h', 'i']),
    ]

    actual_results = list(task.process(learner, interactions))

    # The learner was asked to predict once per interaction, in order.
    self.assertEqual([(1, [1, 2, 3]), (2, [4, 5, 6]), (3, [7, 8, 9])], learner.predict_calls)
    # RecordingLearner returns a one-hot distribution plus an info token each call.
    self.assertEqual([([1, 0, 0], 1), ([0, 1, 0], 2), ([0, 0, 1], 3)], learner.predict_returns)
    # Learn receives (context, action, reward, probability, info) for the chosen action.
    self.assertEqual([(1, 1, 7, 1, 1), (2, 5, 5, 1, 2), (3, 9, 3, 1, 3)], learner.learn_calls)
    self.assertEqual(
        [
            {"rewards": 7, 'learn': 1, 'predict': 1},
            {"rewards": 5, 'learn': 2, 'predict': 2, 'letters': 'e'},
            {"rewards": 3, 'learn': 3, 'predict': 3, 'letters': 'i'},
        ],
        actual_results,
    )
def test_time(self):
    """With time=True each result should carry predict/learn timing entries."""

    task    = OnlineOnPolicyEvalTask(time=True)
    learner = RecordingLearner()

    results = list(task.process(learner, [SimulatedInteraction(1, [1, 2, 3], rewards=[7, 8, 9])]))

    # RecordingLearner does no real work, so both timings round to ~0 seconds.
    self.assertAlmostEqual(0, results[0]["predict_time"], places=2)
    self.assertAlmostEqual(0, results[0]["learn_time"], places=2)
def test_process_sparse_rewards_no_info_no_logs_no_kwargs(self):
    """Sparse (dict) contexts/actions flow through predict, learn, and results unchanged."""

    task    = OnlineOnPolicyEvalTask(time=False)
    learner = RecordingLearner(with_info=False, with_log=False)

    interactions = [
        SimulatedInteraction({'c': 1}, [{'a': 1}, {'a': 2}], rewards=[7, 8]),
        SimulatedInteraction({'c': 2}, [{'a': 4}, {'a': 5}], rewards=[4, 5]),
    ]

    actual_results = list(task.process(learner, interactions))

    # Predict is called with the sparse context and the full sparse action set.
    self.assertEqual([({'c': 1}, [{'a': 1}, {'a': 2}]), ({'c': 2}, [{'a': 4}, {'a': 5}])], learner.predict_calls)
    self.assertEqual([[1, 0], [0, 1]], learner.predict_returns)
    # Without info the learn call's final (info) slot is None.
    self.assertEqual([({'c': 1}, {'a': 1}, 7, 1, None), ({'c': 2}, {'a': 5}, 5, 1, None)], learner.learn_calls)
    # No info/logs/kwargs means the result dicts contain only the observed reward.
    self.assertEqual([{"rewards": 7}, {"rewards": 5}], actual_results)
def test_process_reveals_rewards_no_info_no_logs_no_kwargs(self):
    """When 'reveals' is given the learner is trained on reveals, not rewards.

    Results must still report both the revealed value and the true reward for
    the action selected by the one-hot predictions.
    """

    task    = OnlineOnPolicyEvalTask(time=False)
    learner = RecordingLearner(with_info=False, with_log=False)

    interactions = [
        SimulatedInteraction(1, [1, 2, 3], reveals=[7, 8, 9], rewards=[1, 3, 5]),
        SimulatedInteraction(2, [4, 5, 6], reveals=[4, 5, 6], rewards=[2, 4, 6]),
        SimulatedInteraction(3, [7, 8, 9], reveals=[1, 2, 3], rewards=[3, 5, 7]),
    ]

    actual_results = list(task.process(learner, interactions))

    self.assertEqual([(1, [1, 2, 3]), (2, [4, 5, 6]), (3, [7, 8, 9])], learner.predict_calls)
    self.assertEqual([[1, 0, 0], [0, 1, 0], [0, 0, 1]], learner.predict_returns)
    # Learn is fed the *revealed* value (7, 5, 3) for the chosen action, info is None.
    self.assertEqual([(1, 1, 7, 1, None), (2, 5, 5, 1, None), (3, 9, 3, 1, None)], learner.learn_calls)
    self.assertEqual(
        [
            {"reveals": 7, "rewards": 1},
            {"reveals": 5, "rewards": 4},
            {"reveals": 3, "rewards": 7},
        ],
        actual_results,
    )