def test_meta_evaluator():
    set_seed(100)
    tasks = SetTaskSampler(PointEnv, wrapper=set_length)
    max_episode_length = 200
    with tempfile.TemporaryDirectory() as log_dir_name:
        trainer = Trainer(
            SnapshotConfig(snapshot_dir=log_dir_name,
                           snapshot_mode='last',
                           snapshot_gap=1))
        env = PointEnv(max_episode_length=max_episode_length)
        algo = OptimalActionInference(env=env,
                                      max_episode_length=max_episode_length)
        trainer.setup(algo, env)
        meta_eval = MetaEvaluator(test_task_sampler=tasks, n_test_tasks=10)
        log_file = tempfile.NamedTemporaryFile()
        csv_output = CsvOutput(log_file.name)
        logger.add_output(csv_output)
        meta_eval.evaluate(algo)
        logger.log(tabular)
        meta_eval.evaluate(algo)
        logger.log(tabular)
        logger.dump_output_type(CsvOutput)
        logger.remove_output_type(CsvOutput)
        with open(log_file.name, 'r') as file:
            rows = list(csv.DictReader(file))
        assert len(rows) == 2
        assert float(
            rows[0]['MetaTest/__unnamed_task__/TerminationRate']) < 1.0
        assert float(rows[0]['MetaTest/__unnamed_task__/Iteration']) == 0
        assert (float(rows[0]['MetaTest/__unnamed_task__/MaxReturn']) >=
                float(rows[0]['MetaTest/__unnamed_task__/AverageReturn']))
        assert (float(rows[0]['MetaTest/__unnamed_task__/AverageReturn']) >=
                float(rows[0]['MetaTest/__unnamed_task__/MinReturn']))
        assert float(rows[1]['MetaTest/__unnamed_task__/Iteration']) == 1

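# NOTE: `set_length`, passed as the SetTaskSampler wrapper above and in the TF
# test below, is not defined in these snippets. The sketch that follows is an
# assumption, not the original helper: it presumes the sampler calls
# wrapper(env, task) on each environment it constructs, and that the tests only
# need every sampled PointEnv to share the same episode limit.
from garage.envs import PointEnv


def set_length(env, _task):
    """Hypothetical wrapper: rebuild the sampled env with a fixed episode limit."""
    del env  # the freshly constructed env is discarded in this sketch
    return PointEnv(max_episode_length=200)
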
def test_pickle_meta_evaluator():
    set_seed(100)
    tasks = SetTaskSampler(lambda: GarageEnv(PointEnv()))
    max_path_length = 200
    env = GarageEnv(PointEnv())
    n_traj = 3
    with tempfile.TemporaryDirectory() as log_dir_name:
        runner = LocalRunner(
            SnapshotConfig(snapshot_dir=log_dir_name,
                           snapshot_mode='last',
                           snapshot_gap=1))
        meta_eval = MetaEvaluator(test_task_sampler=tasks,
                                  max_path_length=max_path_length,
                                  n_test_tasks=10,
                                  n_exploration_traj=n_traj)
        policy = RandomPolicy(env.spec.action_space)
        algo = MockAlgo(env, policy, max_path_length, n_traj, meta_eval)
        runner.setup(algo, env)
        log_file = tempfile.NamedTemporaryFile()
        csv_output = CsvOutput(log_file.name)
        logger.add_output(csv_output)
        meta_eval.evaluate(algo)
        meta_eval_pickle = cloudpickle.dumps(meta_eval)
        meta_eval2 = cloudpickle.loads(meta_eval_pickle)
        meta_eval2.evaluate(algo)

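# NOTE: `RandomPolicy` and `MockAlgo` are test helpers that are not defined in
# these snippets, and the MockAlgo constructor signature varies across them
# (some tests pass a policy and evaluator, one does not). The sketches below
# are assumptions inferred from how the helpers are called: MetaEvaluator only
# needs an algorithm exposing get_exploration_policy() and adapt_policy(), and
# the policy only needs to produce actions for the sampler. Depending on the
# garage version, runner.setup() may additionally require a sampler_cls
# attribute on the algorithm.
class RandomPolicy:
    """Stateless policy that samples uniformly from the action space."""

    def __init__(self, action_space):
        self._action_space = action_space

    def reset(self):
        pass

    def get_action(self, observation):
        del observation
        return self._action_space.sample(), {}


class MockAlgo:
    """Minimal stand-in for a meta-RL algorithm (signature assumed)."""

    def __init__(self, env, policy, max_path_length, n_exploration_traj,
                 meta_eval):
        self.env = env
        self.policy = policy
        self.max_path_length = max_path_length
        self.n_exploration_traj = n_exploration_traj
        self.meta_eval = meta_eval

    def get_exploration_policy(self):
        return self.policy

    def adapt_policy(self, exploration_policy, exploration_trajectories):
        # A real algorithm would adapt to the exploration data here; the mock
        # simply hands the unchanged policy back to the evaluator.
        del exploration_trajectories
        return exploration_policy
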
def test_meta_evaluator_with_tf():
    set_seed(100)
    tasks = SetTaskSampler(PointEnv, wrapper=set_length)
    max_episode_length = 200
    env = PointEnv()
    n_eps = 3
    with tempfile.TemporaryDirectory() as log_dir_name:
        ctxt = SnapshotConfig(snapshot_dir=log_dir_name,
                              snapshot_mode='none',
                              snapshot_gap=1)
        with TFTrainer(ctxt) as trainer:
            meta_eval = MetaEvaluator(test_task_sampler=tasks,
                                      n_test_tasks=10,
                                      n_exploration_eps=n_eps)
            policy = GaussianMLPPolicy(env.spec)
            algo = MockAlgo(env, policy, max_episode_length, n_eps, meta_eval)
            trainer.setup(algo, env)
            log_file = tempfile.NamedTemporaryFile()
            csv_output = CsvOutput(log_file.name)
            logger.add_output(csv_output)
            meta_eval.evaluate(algo)
            algo_pickle = cloudpickle.dumps(algo)
        # Unpickle into a fresh default graph and trainer to check that the
        # algorithm (and its evaluator) restore cleanly after serialization.
        tf.compat.v1.reset_default_graph()
        with TFTrainer(ctxt) as trainer:
            algo2 = cloudpickle.loads(algo_pickle)
            trainer.setup(algo2, env)
            trainer.train(10, 0)

def test_meta_evaluator_with_tf():
    set_seed(100)
    tasks = SetTaskSampler(lambda: GarageEnv(PointEnv()))
    max_path_length = 200
    env = GarageEnv(PointEnv())
    n_traj = 3
    with tempfile.TemporaryDirectory() as log_dir_name:
        ctxt = SnapshotConfig(snapshot_dir=log_dir_name,
                              snapshot_mode='none',
                              snapshot_gap=1)
        with LocalTFRunner(ctxt) as runner:
            meta_eval = MetaEvaluator(test_task_sampler=tasks,
                                      max_path_length=max_path_length,
                                      n_test_tasks=10,
                                      n_exploration_traj=n_traj)
            policy = GaussianMLPPolicy(env.spec)
            algo = MockAlgo(env, policy, max_path_length, n_traj, meta_eval)
            runner.setup(algo, env)
            log_file = tempfile.NamedTemporaryFile()
            csv_output = CsvOutput(log_file.name)
            logger.add_output(csv_output)
            meta_eval.evaluate(algo)
            algo_pickle = cloudpickle.dumps(algo)
        # Unpickle into a fresh graph and runner to check that the algorithm
        # restores cleanly after serialization.
        with tf.Graph().as_default():
            with LocalTFRunner(ctxt) as runner:
                algo2 = cloudpickle.loads(algo_pickle)
                runner.setup(algo2, env)
                runner.train(10, 0)

def test_one_folder(self, meta_train_dir, itrs):
    snapshot_config = SnapshotConfig(snapshot_dir=meta_train_dir,
                                     snapshot_mode='all',
                                     snapshot_gap=1)
    runner = LocalRunner(snapshot_config=snapshot_config)
    meta_sampler = AllSetTaskSampler(self.meta_task_cls)
    runner.restore(meta_train_dir)
    meta_evaluator = MetaEvaluator(
        runner,
        test_task_sampler=meta_sampler,
        max_path_length=self.max_path_length,
        n_test_tasks=meta_sampler.n_tasks,
        n_exploration_traj=self.adapt_rollout_per_task,
        prefix='')
    for itr in itrs:
        log_filename = os.path.join(meta_train_dir,
                                    'meta-test-itr_{}.csv'.format(itr))
        logger.add_output(CsvOutput(log_filename))
        logger.log('Writing into {}'.format(log_filename))
        runner.restore(meta_train_dir, from_epoch=itr)
        meta_evaluator.evaluate(runner._algo, self.test_rollout_per_task)
        tabular.record('Iteration', runner._stats.total_epoch)
        tabular.record('TotalEnvSteps', runner._stats.total_env_steps)
        logger.log(tabular)
        logger.dump_output_type(CsvOutput)
        logger.remove_output_type(CsvOutput)

def test_meta_evaluator_n_traj():
    set_seed(100)
    tasks = SetTaskSampler(PointEnv)
    max_path_length = 200
    env = MetaRLEnv(PointEnv())
    n_traj = 3
    with tempfile.TemporaryDirectory() as log_dir_name:
        runner = LocalRunner(
            SnapshotConfig(snapshot_dir=log_dir_name,
                           snapshot_mode='last',
                           snapshot_gap=1))
        algo = MockAlgo(env, max_path_length, n_traj)
        runner.setup(algo, env)
        meta_eval = MetaEvaluator(runner,
                                  test_task_sampler=tasks,
                                  max_path_length=max_path_length,
                                  n_test_tasks=10,
                                  n_exploration_traj=n_traj)
        log_file = tempfile.NamedTemporaryFile()
        csv_output = CsvOutput(log_file.name)
        logger.add_output(csv_output)
        meta_eval.evaluate(algo)

class TestCsvOutput:

    def setup_method(self):
        self.log_file = tempfile.NamedTemporaryFile()
        self.csv_output = CsvOutput(self.log_file.name)
        self.tabular = TabularInput()
        self.tabular.clear()

    def teardown_method(self):
        self.log_file.close()

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)
        self.csv_output.record(self.tabular)
        self.csv_output.dump()

        correct = [
            {'foo': str(foo), 'bar': str(bar)},
            {'foo': str(foo * 2), 'bar': str(bar * 2)},
        ]  # yapf: disable
        self.assert_csv_matches(correct)

    def test_record_inconsistent(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)
        with pytest.warns(CsvOutputWarning):
            self.csv_output.record(self.tabular)

        # this should not produce a warning, because we only warn once
        self.csv_output.record(self.tabular)

        self.csv_output.dump()

        correct = [
            {'foo': str(foo)},
            {'foo': str(foo * 2)},
        ]  # yapf: disable
        self.assert_csv_matches(correct)

    def test_empty_record(self):
        self.csv_output.record(self.tabular)
        assert not self.csv_output._writer

        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        assert not self.csv_output._warned_once

    def test_unacceptable_type(self):
        with pytest.raises(ValueError):
            self.csv_output.record('foo')

    def test_disable_warnings(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)

        self.csv_output.disable_warnings()

        # this should not produce a warning, because we disabled warnings
        self.csv_output.record(self.tabular)

    def assert_csv_matches(self, correct):
        """Check the first row of a csv file and compare it to known values."""
        with open(self.log_file.name, 'r') as file:
            reader = csv.DictReader(file)
            for correct_row in correct:
                row = next(reader)
                assert row == correct_row

def setup_method(self):
    self.log_file = tempfile.NamedTemporaryFile()
    self.csv_output = CsvOutput(self.log_file.name)
    self.tabular = TabularInput()
    self.tabular.clear()

class TestCsvOutput:

    def setup_method(self):
        self.log_file = tempfile.NamedTemporaryFile()
        self.csv_output = CsvOutput(self.log_file.name)
        self.tabular = TabularInput()
        self.tabular.clear()

    def teardown_method(self):
        self.log_file.close()

    def test_record(self):
        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        self.tabular.record('foo', foo * 2)
        self.tabular.record('bar', bar * 2)
        self.csv_output.record(self.tabular)
        self.csv_output.dump()

        correct = [
            {'foo': str(foo), 'bar': str(bar)},
            {'foo': str(foo * 2), 'bar': str(bar * 2)},
        ]  # yapf: disable
        self.assert_csv_matches(correct)
        assert not os.path.exists('{}.tmp'.format(self.log_file.name))

    def test_key_inconsistent(self):
        for i in range(4):
            self.tabular.record('itr', i)
            self.tabular.record('loss', 100.0 / (2 + i))

            # keys added after the first row ('x', 'y') must not break CSV
            # logging; earlier rows are back-filled with empty strings
            if i > 0:
                self.tabular.record('x', i)
            if i > 1:
                self.tabular.record('y', i + 1)
            self.csv_output.record(self.tabular)
            self.csv_output.dump()

        correct = [{
            'itr': str(0),
            'loss': str(100.0 / 2.),
            'x': '',
            'y': ''
        }, {
            'itr': str(1),
            'loss': str(100.0 / 3.),
            'x': str(1),
            'y': ''
        }, {
            'itr': str(2),
            'loss': str(100.0 / 4.),
            'x': str(2),
            'y': str(3)
        }, {
            'itr': str(3),
            'loss': str(100.0 / 5.),
            'x': str(3),
            'y': str(4)
        }]
        self.assert_csv_matches(correct)

    def test_empty_record(self):
        self.csv_output.record(self.tabular)
        self.csv_output.dump()

        foo = 1
        bar = 10
        self.tabular.record('foo', foo)
        self.tabular.record('bar', bar)
        self.csv_output.record(self.tabular)
        self.csv_output.dump()

        # Empty lines are not recorded
        self.assert_csv_matches([{'foo': str(foo), 'bar': str(bar)}])

    def test_unacceptable_type(self):
        with pytest.raises(ValueError):
            self.csv_output.record('foo')

    def assert_csv_matches(self, correct):
        """Compare every row of the csv file to the expected rows."""
        with open(self.log_file.name, 'r') as file:
            contents = list(csv.DictReader(file))
        assert len(contents) == len(correct)
        for row, correct_row in zip(contents, correct):
            assert sorted(list(row.items())) == sorted(
                list(correct_row.items()))

def setup_method(self):
    self.log_file = tempfile.NamedTemporaryFile()
    self.csv_output = CsvOutput(self.log_file.name)

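# Standalone usage sketch of the dowel API exercised by the TestCsvOutput
# snippets above (the file name and recorded values are illustrative only):
# TabularInput collects key/value pairs for one row, CsvOutput.record() writes
# the current row, and dump() flushes the file so it can be read back.
import csv
import tempfile

from dowel import CsvOutput, TabularInput

log_file = tempfile.NamedTemporaryFile()
output = CsvOutput(log_file.name)
tabular = TabularInput()

for step in range(3):
    tabular.record('step', step)
    tabular.record('loss', 1.0 / (step + 1))
    output.record(tabular)  # write the current tabular values as one CSV row
    tabular.clear()

output.dump()  # flush buffered writes to disk

with open(log_file.name) as f:
    rows = list(csv.DictReader(f))
assert len(rows) == 3 and rows[0]['step'] == '0'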