def create_exploration_runner(base_dir,
                              create_agent_fn,
                              schedule='continuous_train_and_eval'):
  """Creates an experiment Runner.

  Args:
    base_dir: Base directory for hosting all subdirectories.
    create_agent_fn: A function that takes as args a TensorFlow session and a
      Gym Atari 2600 environment, and returns an agent.
    schedule: string, which type of Runner to use.

  Returns:
    runner: A `run_experiment.Runner`-like object.

  Raises:
    ValueError: When an unknown schedule is encountered.
  """
  assert base_dir is not None
  # Continuously runs training and evaluation until num_iterations is reached.
  if schedule == 'continuous_train_and_eval':
    return run_experiment.Runner(base_dir, create_agent_fn)
  # Continuously runs training until num_iterations is reached.
  elif schedule == 'continuous_train':
    return run_experiment.TrainRunner(base_dir, create_agent_fn)
  else:
    raise ValueError('Unknown schedule: {}'.format(schedule))
def testRunOnePhase(self):
  max_steps = 10
  environment_steps = 2
  environment = MockEnvironment(max_steps=environment_steps)
  statistics = []
  runner = run_experiment.Runner(self._test_subdir, self._create_agent_fn,
                                 lambda: environment)
  step_number, sum_returns, num_episodes = runner._run_one_phase(
      max_steps, statistics, 'test')
  calls_to_run_episode = int(max_steps / environment_steps)
  self.assertEqual(self._agent.step.call_count, calls_to_run_episode)
  self.assertEqual(self._agent.end_episode.call_count, calls_to_run_episode)
  self.assertEqual(max_steps, step_number)
  self.assertEqual(-1 * calls_to_run_episode, sum_returns)
  self.assertEqual(calls_to_run_episode, num_episodes)
  expected_statistics = []
  for _ in range(calls_to_run_episode):
    expected_statistics.append({
        'test_episode_lengths': 2,
        'test_episode_returns': -1
    })
  self.assertEqual(len(expected_statistics), len(statistics))
  for i in range(len(statistics)):
    self.assertDictEqual(expected_statistics[i], statistics[i])
def testRunOneIteration(self):
  environment_steps = 2
  environment = MockEnvironment(max_steps=environment_steps)
  training_steps = 20
  evaluation_steps = 10
  runner = run_experiment.Runner(self._test_subdir, self._create_agent_fn,
                                 lambda: environment,
                                 training_steps=training_steps,
                                 evaluation_steps=evaluation_steps)
  dictionary = runner._run_one_iteration(1)
  train_calls = int(training_steps / environment_steps)
  eval_calls = int(evaluation_steps / environment_steps)
  expected_dictionary = {
      'train_episode_lengths': [2 for _ in range(train_calls)],
      'train_episode_returns': [-1 for _ in range(train_calls)],
      'train_average_return': [-1],
      'eval_episode_lengths': [2 for _ in range(eval_calls)],
      'eval_episode_returns': [-1 for _ in range(eval_calls)],
      'eval_average_return': [-1]
  }
  for k in expected_dictionary:
    self.assertEqual(expected_dictionary[k], dictionary[k])
  # Also verify that the average number of steps per second is present and
  # positive.
  self.assertEqual(len(dictionary['train_average_steps_per_second']), 1)
  self.assertGreater(dictionary['train_average_steps_per_second'][0], 0)
def create_runner(base_dir,
                  create_agent_fn,
                  schedule='continuous_train_and_eval'):
  """Creates an experiment Runner.

  TODO(b/): Figure out the right idiom to create a Runner. The current
  mechanism of using a number of flags will not scale and is not elegant.

  Args:
    base_dir: Base directory for hosting all subdirectories.
    create_agent_fn: A function that takes as args a TensorFlow session and a
      Gym Atari 2600 environment, and returns an agent.
    schedule: string, which type of Runner to use.

  Returns:
    runner: A `run_experiment.Runner`-like object.

  Raises:
    ValueError: When an unknown schedule is encountered.
  """
  assert base_dir is not None
  # Continuously runs training and evaluation until num_iterations is reached.
  if schedule == 'continuous_train_and_eval':
    return run_experiment.Runner(base_dir, create_agent_fn,
                                 atari_lib.create_atari_environment)
  # Continuously runs training until num_iterations is reached.
  elif schedule == 'continuous_train':
    return run_experiment.TrainRunner(base_dir, create_agent_fn,
                                      atari_lib.create_atari_environment)
  else:
    raise ValueError('Unknown schedule: {}'.format(schedule))
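# A minimal usage sketch for create_runner above, assuming the gin configs
# have already been loaded and that an agent factory with Dopamine's
# (sess, environment, summary_writer) signature is available. The base
# directory path below is a hypothetical placeholder.
run_experiment.load_gin_configs(['dopamine/agents/dqn/configs/dqn.gin'], [])
runner = create_runner('/tmp/dopamine_sketch', run_experiment.create_agent,
                       schedule='continuous_train_and_eval')
runner.run_experiment()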
def testCheckpointExperiment(self, mock_logger_constructor,
                             mock_checkpointer_constructor):
  checkpoint_dir = os.path.join(self._test_subdir, 'checkpoints')
  test_dict = {'test': 1}
  iteration = 1729

  def bundle_and_checkpoint(x, y):
    self.assertEqual(checkpoint_dir, x)
    self.assertEqual(iteration, y)
    return test_dict

  self._agent.bundle_and_checkpoint.side_effect = bundle_and_checkpoint
  experiment_checkpointer = mock.Mock()
  mock_checkpointer_constructor.return_value = experiment_checkpointer
  logs_data = {'one': 1, 'two': 2}
  mock_logger = MockLogger(run_asserts=False, data=logs_data)
  mock_logger_constructor.return_value = mock_logger
  runner = run_experiment.Runner(self._test_subdir, self._create_agent_fn,
                                 mock.Mock)
  runner._checkpoint_experiment(iteration)
  self.assertEqual(1, experiment_checkpointer.save_checkpoint.call_count)
  mock_args, _ = experiment_checkpointer.save_checkpoint.call_args
  self.assertEqual(iteration, mock_args[0])
  test_dict['logs'] = logs_data
  test_dict['current_iteration'] = iteration
  self.assertDictEqual(test_dict, mock_args[1])
def testRunOneEpisodeWithLowMaxSteps(self):
  max_steps_per_episode = 2
  environment = MockEnvironment()
  runner = run_experiment.Runner(
      self._test_subdir, self._create_agent_fn, lambda: environment,
      max_steps_per_episode=max_steps_per_episode)
  step_number, total_reward = runner._run_one_episode()
  self.assertEqual(self._agent.step.call_count, max_steps_per_episode - 1)
  self.assertEqual(self._agent.end_episode.call_count, 1)
  self.assertEqual(max_steps_per_episode, step_number)
  self.assertEqual(-1, total_reward)
def testRunOneEpisode(self):
  max_steps_per_episode = 11
  environment = MockEnvironment()
  runner = run_experiment.Runner(
      self._test_subdir, self._create_agent_fn, lambda: environment,
      max_steps_per_episode=max_steps_per_episode)
  step_number, total_reward = runner._run_one_episode()
  self.assertEqual(self._agent.step.call_count, environment.max_steps - 1)
  self.assertEqual(self._agent.end_episode.call_count, 1)
  self.assertEqual(environment.max_steps, step_number)
  # Expected reward: \sum_{i=0}^{9} (-1)**i * i
  #   = 0 - 1 + 2 - 3 + 4 - 5 + 6 - 7 + 8 - 9 = -5.
  self.assertEqual(-5, total_reward)
def testRunExperimentWithInconsistentRange(self, mock_logger_constructor,
                                           mock_checkpointer_constructor):
  experiment_logger = MockLogger()
  mock_logger_constructor.return_value = experiment_logger
  experiment_checkpointer = mock.Mock()
  mock_checkpointer_constructor.return_value = experiment_checkpointer
  runner = run_experiment.Runner(
      self._test_subdir, self._create_agent_fn, mock.Mock, num_iterations=0)
  runner.run_experiment()
  self.assertEqual(0, experiment_checkpointer.save_checkpoint.call_count)
  self.assertEqual(0, experiment_logger._calls_to_set)
  self.assertEqual(0, experiment_logger._calls_to_log)
def testDefaultGinDqn(self):
  """Test DQNAgent configuration using the default gin config."""
  tf.logging.info('####### Training the DQN agent #####')
  tf.logging.info('####### DQN base_dir: {}'.format(self._base_dir))
  gin_files = ['dopamine/agents/dqn/configs/dqn.gin']
  gin_bindings = [
      'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
      "create_agent.agent_name = 'dqn'"
  ]
  run_experiment.load_gin_configs(gin_files, gin_bindings)
  runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                 atari_lib.create_atari_environment)
  self.assertIsInstance(runner._agent.optimizer, tf.train.RMSPropOptimizer)
  self.assertNear(0.00025, runner._agent.optimizer._learning_rate, 0.0001)
  shutil.rmtree(self._base_dir)
def testDefaultGinRainbow(self):
  """Test RainbowAgent default configuration using default gin."""
  tf.logging.info('####### Training the RAINBOW agent #####')
  tf.logging.info('####### RAINBOW base_dir: {}'.format(self._base_dir))
  gin_files = ['dopamine/agents/rainbow/configs/rainbow.gin']
  gin_bindings = [
      'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
      "create_agent.agent_name = 'rainbow'"
  ]
  run_experiment.load_gin_configs(gin_files, gin_bindings)
  runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                 atari_lib.create_atari_environment)
  self.assertIsInstance(runner._agent.optimizer, tf.train.AdamOptimizer)
  self.assertNear(0.0000625, runner._agent.optimizer._lr, 0.0001)
  shutil.rmtree(self._base_dir)
def testLogExperiment(self, mock_logger_constructor):
  log_every_n = 2
  logging_file_prefix = 'prefix'
  statistics = 'statistics'
  experiment_logger = MockLogger(test_cls=self)
  mock_logger_constructor.return_value = experiment_logger
  runner = run_experiment.Runner(
      self._test_subdir, self._create_agent_fn, mock.Mock,
      logging_file_prefix=logging_file_prefix,
      log_every_n=log_every_n)
  num_iterations = 10
  for i in range(num_iterations):
    runner._log_experiment(i, statistics)
  self.assertEqual(num_iterations, experiment_logger._calls_to_set)
  self.assertEqual((num_iterations / log_every_n),
                   experiment_logger._calls_to_log)
def testRewardClipping(self, reward_clipping, reward, expected_reward):
  environment = tf.test.mock.Mock()
  environment.step.return_value = (0, reward, True, {})
  mock_agent = tf.test.mock.Mock()
  agent_fn = tf.test.mock.MagicMock(return_value=mock_agent)
  runner = run_experiment.Runner(self.get_temp_dir(), agent_fn,
                                 lambda: environment,
                                 log_every_n=1,
                                 num_iterations=1,
                                 training_steps=1,
                                 evaluation_steps=0,
                                 reward_clipping=reward_clipping)
  runner._checkpoint_experiment = tf.test.mock.Mock()
  runner._log_experiment = tf.test.mock.Mock()
  runner.run_experiment()
  mock_agent.end_episode.assert_called_once_with(expected_reward)
def testDefaultGinImplicitQuantileIcml(self):
  """Test default ImplicitQuantile configuration using ICML gin."""
  tf.logging.info('###### Training the Implicit Quantile agent #####')
  self._base_dir = os.path.join(
      '/tmp/dopamine_tests',
      datetime.datetime.utcnow().strftime('run_%Y_%m_%d_%H_%M_%S'))
  tf.logging.info('###### IQN base dir: {}'.format(self._base_dir))
  gin_files = ['dopamine/agents/'
               'implicit_quantile/configs/implicit_quantile_icml.gin']
  gin_bindings = [
      'Runner.num_iterations=0',
  ]
  run_experiment.load_gin_configs(gin_files, gin_bindings)
  runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                 atari_lib.create_atari_environment)
  self.assertEqual(1000000, runner._agent._replay.memory._replay_capacity)
  shutil.rmtree(self._base_dir)
def testOverrideGinDqn(self):
  """Test DQNAgent configuration overridden with AdamOptimizer."""
  tf.logging.info('####### Training the DQN agent #####')
  tf.logging.info('####### DQN base_dir: {}'.format(self._base_dir))
  gin_files = ['dopamine/agents/dqn/configs/dqn.gin']
  gin_bindings = [
      'DQNAgent.optimizer = @tf.train.AdamOptimizer()',
      'tf.train.AdamOptimizer.learning_rate = 100',
      'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
      "create_agent.agent_name = 'dqn'"
  ]
  run_experiment.load_gin_configs(gin_files, gin_bindings)
  runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                 atari_lib.create_atari_environment)
  self.assertIsInstance(runner._agent.optimizer, tf.train.AdamOptimizer)
  self.assertEqual(100, runner._agent.optimizer._lr)
  shutil.rmtree(self._base_dir)
def testOverrideGinRainbow(self):
  """Test RainbowAgent configuration overridden with RMSPropOptimizer."""
  tf.logging.info('####### Training the RAINBOW agent #####')
  tf.logging.info('####### RAINBOW base_dir: {}'.format(self._base_dir))
  gin_files = [
      'dopamine/agents/rainbow/configs/rainbow.gin',
  ]
  gin_bindings = [
      'RainbowAgent.optimizer = @tf.train.RMSPropOptimizer()',
      'tf.train.RMSPropOptimizer.learning_rate = 100',
      'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
      "create_agent.agent_name = 'rainbow'"
  ]
  run_experiment.load_gin_configs(gin_files, gin_bindings)
  runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                 atari_lib.create_atari_environment)
  self.assertIsInstance(runner._agent.optimizer, tf.train.RMSPropOptimizer)
  self.assertEqual(100, runner._agent.optimizer._learning_rate)
  shutil.rmtree(self._base_dir)
def testRunExperiment(self, mock_logger_constructor,
                      mock_checkpointer_constructor, mock_get_latest):
  log_every_n = 1
  environment = MockEnvironment()
  experiment_logger = MockLogger(run_asserts=False)
  mock_logger_constructor.return_value = experiment_logger
  experiment_checkpointer = mock.Mock()
  start_iteration = 1729
  mock_get_latest.return_value = start_iteration

  def load_checkpoint(_):
    return {'logs': 'log_data', 'current_iteration': start_iteration - 1}

  experiment_checkpointer.load_checkpoint.side_effect = load_checkpoint
  mock_checkpointer_constructor.return_value = experiment_checkpointer

  def bundle_and_checkpoint(x, y):
    del x, y  # Unused.
    return {'test': 1}

  self._agent.bundle_and_checkpoint.side_effect = bundle_and_checkpoint
  num_iterations = 10
  self._agent.unbundle.return_value = True
  end_iteration = start_iteration + num_iterations
  runner = run_experiment.Runner(self._test_subdir, self._create_agent_fn,
                                 lambda: environment,
                                 log_every_n=log_every_n,
                                 num_iterations=end_iteration,
                                 training_steps=1,
                                 evaluation_steps=1)
  self.assertEqual(start_iteration, runner._start_iteration)
  runner.run_experiment()
  self.assertEqual(num_iterations,
                   experiment_checkpointer.save_checkpoint.call_count)
  self.assertEqual(num_iterations, experiment_logger._calls_to_set)
  self.assertEqual(num_iterations, experiment_logger._calls_to_log)
  glob_string = '{}/events.out.tfevents.*'.format(self._test_subdir)
  self.assertGreater(len(tf.gfile.Glob(glob_string)), 0)
def create_runner(self, env_fn, hparams, target_iterations,
                  training_steps_per_iteration):
  # pylint: disable=unbalanced-tuple-unpacking
  agent_params, optimizer_params, \
      runner_params, replay_buffer_params = _parse_hparams(hparams)
  # pylint: enable=unbalanced-tuple-unpacking
  optimizer = _get_optimizer(optimizer_params)
  agent_params["optimizer"] = optimizer
  agent_params.update(replay_buffer_params)
  create_agent_fn = get_create_agent(agent_params)
  runner = run_experiment.Runner(
      base_dir=self.agent_model_dir,
      create_agent_fn=create_agent_fn,
      create_environment_fn=get_create_env_fun(
          env_fn, time_limit=hparams.time_limit),
      evaluation_steps=0,
      num_iterations=target_iterations,
      training_steps=training_steps_per_iteration,
      **runner_params)
  return runner
def main(unused_argv):
  """Main method.

  Args:
    unused_argv: Arguments (unused).
  """
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.disable_v2_behavior()
  base_dir = FLAGS.base_dir
  gin_files, gin_bindings = FLAGS.gin_files, FLAGS.gin_bindings
  run_experiment.load_gin_configs(gin_files, gin_bindings)
  # Set the Jax agent seed using the run number.
  create_agent_fn = functools.partial(create_agent, seed=FLAGS.run_number)
  if FLAGS.max_episode_eval:
    runner_fn = eval_run_experiment.MaxEpisodeEvalRunner
    logging.info('Using MaxEpisodeEvalRunner for evaluation.')
    runner = runner_fn(base_dir, create_agent_fn)
  else:
    runner = run_experiment.Runner(base_dir, create_agent_fn)
  runner.run_experiment()
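# The main() above reads several absl flags assumed to be defined elsewhere in
# its module. A minimal sketch of plausible definitions, inferred from the
# usage above (the names match the code; defaults and help strings are
# assumptions):
from absl import flags

flags.DEFINE_string('base_dir', None, 'Base directory for the experiment.')
flags.DEFINE_multi_string('gin_files', [], 'Paths to gin config files.')
flags.DEFINE_multi_string('gin_bindings', [], 'Gin parameter bindings.')
flags.DEFINE_integer('run_number', 0, 'Run number, used to seed the Jax agent.')
flags.DEFINE_boolean('max_episode_eval', False,
                     'Whether to use MaxEpisodeEvalRunner for evaluation.')

FLAGS = flags.FLAGS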
def testRunOneIteration(self):
  environment_steps = 2
  environment = MockEnvironment(max_steps=environment_steps)
  training_steps = 20
  evaluation_steps = 10
  runner = run_experiment.Runner(
      self._test_subdir, self._create_agent_fn, lambda: environment,
      training_steps=training_steps, evaluation_steps=evaluation_steps)
  dictionary = runner._run_one_iteration(1)
  train_calls = int(training_steps / environment_steps)
  eval_calls = int(evaluation_steps / environment_steps)
  expected_dictionary = {
      'train_episode_lengths': [2 for _ in range(train_calls)],
      'train_episode_returns': [-1 for _ in range(train_calls)],
      'train_average_return': [-1],
      'eval_episode_lengths': [2 for _ in range(eval_calls)],
      'eval_episode_returns': [-1 for _ in range(eval_calls)],
      'eval_average_return': [-1]
  }
  self.assertDictEqual(expected_dictionary, dictionary)
def testInitializeCheckpointingWhenCheckpointUnbundleSucceeds(
    self, mock_get_latest):
  latest_checkpoint = 7
  mock_get_latest.return_value = latest_checkpoint
  logs_data = {'a': 1, 'b': 2}
  current_iteration = 1729
  checkpoint_data = {'current_iteration': current_iteration,
                     'logs': logs_data}
  checkpoint_dir = os.path.join(self._test_subdir, 'checkpoints')
  checkpoint = checkpointer.Checkpointer(checkpoint_dir, 'ckpt')
  checkpoint.save_checkpoint(latest_checkpoint, checkpoint_data)
  mock_agent = mock.Mock()
  mock_agent.unbundle.return_value = True
  runner = run_experiment.Runner(self._test_subdir,
                                 lambda x, y, summary_writer: mock_agent,
                                 mock.Mock)
  expected_iteration = current_iteration + 1
  self.assertEqual(expected_iteration, runner._start_iteration)
  self.assertDictEqual(logs_data, runner._logger.data)
  mock_agent.unbundle.assert_called_once_with(
      checkpoint_dir, latest_checkpoint, checkpoint_data)
def dopamine_train(base_dir, hidden_layer_size, gamma, learning_rate,
                   num_train_steps, network='chain'):
  """Train an agent using dopamine."""
  runner = run_experiment.Runner(
      base_dir,
      functools.partial(
          _create_agent,
          hidden_layer_size=hidden_layer_size,
          gamma=gamma,
          learning_rate=learning_rate),
      functools.partial(_create_environment, network=network),
      num_iterations=num_train_steps,
      training_steps=500,
      evaluation_steps=100,
      max_steps_per_episode=20)
  runner.run_experiment()
  return runner
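# A minimal sketch of driving dopamine_train above; the hyperparameter values
# are illustrative placeholders rather than tuned settings, and the base
# directory is hypothetical.
trained_runner = dopamine_train(
    base_dir='/tmp/dopamine_chain_sketch',
    hidden_layer_size=64,
    gamma=0.99,
    learning_rate=0.001,
    num_train_steps=10,
    network='chain')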
def testOverrideGinImplicitQuantile(self):
  """Test ImplicitQuantile configuration overriding using IQN gin."""
  logging.info('###### Training the Implicit Quantile agent #####')
  self._base_dir = os.path.join(
      '/tmp/dopamine_tests',
      datetime.datetime.utcnow().strftime('run_%Y_%m_%d_%H_%M_%S'))
  logging.info('###### IQN base dir: %s', self._base_dir)
  gin_files = [
      'dopamine/agents/implicit_quantile/configs/implicit_quantile.gin'
  ]
  gin_bindings = [
      'Runner.num_iterations=0',
      'WrappedPrioritizedReplayBuffer.replay_capacity = 1000',
  ]
  run_experiment.load_gin_configs(gin_files, gin_bindings)
  runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                 atari_lib.create_atari_environment)
  self.assertEqual(1000, runner._agent._replay.memory._replay_capacity)
  shutil.rmtree(self._base_dir)
def testInitializeCheckpointingWhenCheckpointUnbundleFails(
    self, mock_logger_constructor, mock_checkpointer_constructor,
    mock_get_latest):
  mock_checkpointer = _create_mock_checkpointer()
  mock_checkpointer_constructor.return_value = mock_checkpointer
  latest_checkpoint = 7
  mock_get_latest.return_value = latest_checkpoint
  agent = mock.Mock()
  agent.unbundle.return_value = False
  mock_logger = mock.Mock()
  mock_logger_constructor.return_value = mock_logger
  runner = run_experiment.Runner(self._test_subdir,
                                 lambda x, y, summary_writer: agent,
                                 mock.Mock)
  self.assertEqual(0, runner._start_iteration)
  self.assertEqual(1, mock_checkpointer.load_checkpoint.call_count)
  self.assertEqual(1, agent.unbundle.call_count)
  mock_args, _ = agent.unbundle.call_args
  self.assertEqual('{}/checkpoints'.format(self._test_subdir), mock_args[0])
  self.assertEqual(latest_checkpoint, mock_args[1])
  expected_dictionary = {'current_iteration': 1729, 'logs': 'logs'}
  self.assertDictEqual(expected_dictionary, mock_args[2])
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  runner = run_experiment.Runner(FLAGS.base_dir, create_agent)
  runner.run_experiment()
def make_runner(gin_files):
  run_experiment.load_gin_configs(gin_files, None)
  runner = run_experiment.Runner(base_dir=base_dir,
                                 create_agent_fn=create_agent_fn,
                                 create_environment_fn=create_env_fn)
  return runner
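# Sketch of invoking make_runner above; it assumes the module-level base_dir,
# create_agent_fn, and create_env_fn globals referenced inside the function
# are already defined, and the gin file path is a hypothetical placeholder.
runner = make_runner(['configs/rainbow.gin'])
runner.run_experiment()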
ac_config = """
create_agent.agent_name = 'ac'
run_experiment.Runner.create_environment_fn = @gym_lib.create_gym_environment
run_experiment.Runner.num_iterations = 500
run_experiment.Runner.training_steps = 1000
run_experiment.Runner.evaluation_steps = 1000
run_experiment.Runner.max_steps_per_episode = 200
circular_replay_buffer.WrappedReplayBuffer.replay_capacity = 50000
circular_replay_buffer.WrappedReplayBuffer.batch_size = 128
ACAgent.actor_network = @CartpoleActorNetwork
ACAgent.critic_network = @CartpoleCriticNetwork
ACAgent.observation_shape = %gym_lib.CARTPOLE_OBSERVATION_SHAPE
ACAgent.observation_dtype = %gym_lib.CARTPOLE_OBSERVATION_DTYPE
ACAgent.stack_size = %gym_lib.CARTPOLE_STACK_SIZE
tf.train.AdamOptimizer.learning_rate = 0.00001
tf.train.AdamOptimizer.epsilon = 0.00000390625
"""
tf.logging.set_verbosity(tf.logging.INFO)
gin.parse_config(ac_config, skip_unknown=False)
ac_runner = run_experiment.Runner(LOG_PATH, create_ac_agent)
print('Will train the ac agent; please be patient, this may take a while...')
ac_runner.run_experiment()
print('Done training!')
def run(bsuite_id: str) -> str:
  """Runs Dopamine DQN on a given bsuite environment, logging to CSV."""

  raw_env = bsuite.load_and_record(
      bsuite_id=bsuite_id,
      save_path=FLAGS.save_path,
      logging_mode=FLAGS.logging_mode,
      overwrite=FLAGS.overwrite,
  )

  class Network(tf.keras.Model):
    """Build deep network compatible with dopamine/discrete_domains/gym_lib."""

    def __init__(self, num_actions: int, name='Network'):
      super(Network, self).__init__(name=name)
      self.forward_fn = tf.keras.Sequential(
          [tf.keras.layers.Flatten()] +
          [tf.keras.layers.Dense(FLAGS.num_units,
                                 activation=tf.keras.activations.relu)
           for _ in range(FLAGS.num_hidden_layers)] +
          [tf.keras.layers.Dense(num_actions, activation=None)])

    def call(self, state):
      """Creates the output tensor/op given the state tensor as input."""
      x = tf.cast(state, tf.float32)
      x = self.forward_fn(x)
      return atari_lib.DQNNetworkType(x)

  def create_agent(sess: tf.Session,
                   environment: gym.Env,
                   summary_writer=None):
    """Factory method for agent initialization in Dopamine."""
    del summary_writer
    return dqn_agent.DQNAgent(
        sess=sess,
        num_actions=environment.action_space.n,
        observation_shape=OBSERVATION_SHAPE,
        observation_dtype=tf.float32,
        stack_size=1,
        network=Network,
        gamma=FLAGS.agent_discount,
        update_horizon=1,
        min_replay_history=FLAGS.min_replay_size,
        update_period=FLAGS.sgd_period,
        target_update_period=FLAGS.target_update_period,
        epsilon_decay_period=FLAGS.epsilon_decay_period,
        epsilon_train=FLAGS.epsilon,
        optimizer=tf.train.AdamOptimizer(FLAGS.learning_rate),
    )

  def create_environment() -> gym.Env:
    """Factory method for environment initialization in Dopamine."""
    env = wrappers.ImageObservation(raw_env, OBSERVATION_SHAPE)
    if FLAGS.verbose:
      env = terminal_logging.wrap_environment(env, log_every=True)  # pytype: disable=wrong-arg-types
    env = gym_wrapper.GymFromDMEnv(env)
    env.game_over = False  # Dopamine looks for this attribute.
    return env

  runner = run_experiment.Runner(
      base_dir=FLAGS.base_dir,
      create_agent_fn=create_agent,
      create_environment_fn=create_environment,
  )

  num_episodes = FLAGS.num_episodes or getattr(raw_env, 'bsuite_num_episodes')
  for _ in range(num_episodes):
    runner._run_one_episode()  # pylint: disable=protected-access

  return bsuite_id
env_path = '../env/AnimalAI'
worker_id = random.randint(1, 100)
arena_config_in = ArenaConfig('configs/1-Food.yaml')
base_dir = 'models/dopamine'
gin_files = ['configs/rainbow.gin']


def create_env_fn():
  env = AnimalAIEnv(environment_filename=env_path,
                    worker_id=worker_id,
                    n_arenas=1,
                    arenas_configurations=arena_config_in,
                    docker_training=False,
                    retro=True)
  print("all good create_env_fn")
  return env


def create_agent_fn(sess, env, summary_writer):
  return rainbow_agent.RainbowAgent(sess=sess,
                                    num_actions=env.action_space.n,
                                    summary_writer=summary_writer)


run_experiment.load_gin_configs(gin_files, None)
runner = run_experiment.Runner(base_dir=base_dir,
                               create_agent_fn=create_agent_fn,
                               create_environment_fn=create_env_fn)
# runner.run_experiment()
def main(unused_argv):
  config.update('jax_disable_jit', FLAGS.disable_jit)
  logging.set_verbosity(logging.INFO)
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  runner = run_experiment.Runner(FLAGS.base_dir, create_agent)
  runner.run_experiment()
def testInitializeCheckpointingWithNoCheckpointFile(self, mock_get_latest):
  mock_get_latest.return_value = -1
  base_dir = '/does/not/exist'
  with self.assertRaisesRegexp(tf.errors.PermissionDeniedError,
                               '.*/does.*'):
    run_experiment.Runner(base_dir, self._create_agent_fn, mock.Mock)