def create_exploration_runner(base_dir,
                              create_agent_fn,
                              schedule='continuous_train_and_eval'):
    """Creates an experiment Runner.

  Args:
    base_dir: Base directory for hosting all subdirectories.
    create_agent_fn: A function that takes as args a Tensorflow session and a
      Gym Atari 2600 environment, and returns an agent.
    schedule: string, which type of Runner to use.

  Returns:
    runner: A `run_experiment.Runner` like object.

  Raises:
    ValueError: When an unknown schedule is encountered.
  """
    assert base_dir is not None
    # Continuously runs training and eval till max num_iterations is hit.
    if schedule == 'continuous_train_and_eval':
        return run_experiment.Runner(base_dir, create_agent_fn)
    # Continuously runs training till maximum num_iterations is hit.
    elif schedule == 'continuous_train':
        return run_experiment.TrainRunner(base_dir, create_agent_fn)
    else:
        raise ValueError('Unknown schedule: {}'.format(schedule))
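For context, here is a minimal sketch of how a runner built by the helper above might be driven end to end. The agent factory and the '/tmp/exploration_test' base directory are illustrative assumptions; the gin file path is the same one used in the DQN test examples further below.

from dopamine.agents.dqn import dqn_agent
from dopamine.discrete_domains import run_experiment


def create_agent_fn(sess, environment, summary_writer=None):
    """An agent factory matching the docstring: TF session + Gym env -> agent."""
    return dqn_agent.DQNAgent(sess,
                              num_actions=environment.action_space.n,
                              summary_writer=summary_writer)


# Gin supplies the remaining Runner / environment / agent settings
# (e.g. the Atari game name) before the Runner is constructed.
run_experiment.load_gin_configs(['dopamine/agents/dqn/configs/dqn.gin'], [])
runner = create_exploration_runner('/tmp/exploration_test', create_agent_fn)
runner.run_experiment()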
Example #2
 def testRunOnePhase(self):
     max_steps = 10
     environment_steps = 2
     environment = MockEnvironment(max_steps=environment_steps)
     statistics = []
     runner = run_experiment.Runner(self._test_subdir,
                                    self._create_agent_fn,
                                    lambda: environment)
     step_number, sum_returns, num_episodes = runner._run_one_phase(
         max_steps, statistics, 'test')
     calls_to_run_episode = int(max_steps / environment_steps)
     self.assertEqual(self._agent.step.call_count, calls_to_run_episode)
     self.assertEqual(self._agent.end_episode.call_count,
                      calls_to_run_episode)
     self.assertEqual(max_steps, step_number)
     self.assertEqual(-1 * calls_to_run_episode, sum_returns)
     self.assertEqual(calls_to_run_episode, num_episodes)
     expected_statistics = []
     for _ in range(calls_to_run_episode):
         expected_statistics.append({
             'test_episode_lengths': 2,
             'test_episode_returns': -1
         })
     self.assertEqual(len(expected_statistics), len(statistics))
     for i in range(len(statistics)):
         self.assertDictEqual(expected_statistics[i], statistics[i])
 def testRunOneIteration(self):
     environment_steps = 2
     environment = MockEnvironment(max_steps=environment_steps)
     training_steps = 20
     evaluation_steps = 10
     runner = run_experiment.Runner(self._test_subdir,
                                    self._create_agent_fn,
                                    lambda: environment,
                                    training_steps=training_steps,
                                    evaluation_steps=evaluation_steps)
     dictionary = runner._run_one_iteration(1)
     train_calls = int(training_steps / environment_steps)
     eval_calls = int(evaluation_steps / environment_steps)
     expected_dictionary = {
         'train_episode_lengths': [2 for _ in range(train_calls)],
         'train_episode_returns': [-1 for _ in range(train_calls)],
         'train_average_return': [-1],
         'eval_episode_lengths': [2 for _ in range(eval_calls)],
         'eval_episode_returns': [-1 for _ in range(eval_calls)],
         'eval_average_return': [-1]
     }
     for k in expected_dictionary:
         self.assertEqual(expected_dictionary[k], dictionary[k])
     # Also verify that average number of steps per second is present and
     # positive.
     self.assertEqual(len(dictionary['train_average_steps_per_second']), 1)
     self.assertGreater(dictionary['train_average_steps_per_second'][0], 0)
Example #4
def create_runner(base_dir,
                  create_agent_fn,
                  schedule='continuous_train_and_eval'):
    """Creates an experiment Runner.

  TODO(b/): Figure out the right idiom to create a Runner. The current mechanism
  of using a number of flags will not scale and is not elegant.

  Args:
    base_dir: Base directory for hosting all subdirectories.
    create_agent_fn: A function that takes as args a Tensorflow session and a
      Gym Atari 2600 environment, and returns an agent.
    schedule: string, which type of Runner to use.

  Returns:
    runner: A `run_experiment.Runner` like object.

  Raises:
    ValueError: When an unknown schedule is encountered.
  """
    assert base_dir is not None
    # Continuously runs training and eval till max num_iterations is hit.
    if schedule == 'continuous_train_and_eval':
        return run_experiment.Runner(base_dir, create_agent_fn,
                                     atari_lib.create_atari_environment)
    # Continuously runs training till maximum num_iterations is hit.
    elif schedule == 'continuous_train':
        return run_experiment.TrainRunner(base_dir, create_agent_fn,
                                          atari_lib.create_atari_environment)
    else:
        raise ValueError('Unknown schedule: {}'.format(schedule))
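The schedule argument only switches which Runner class is returned: 'continuous_train_and_eval' yields a Runner that alternates training and evaluation phases, while 'continuous_train' yields a TrainRunner that skips evaluation. A brief hedged illustration, where the base directory is a placeholder and create_agent_fn is any agent factory such as the one sketched after the first example:

# Training-only schedule: TrainRunner omits the eval phase each iteration.
train_runner = create_runner('/tmp/dqn_train_only', create_agent_fn,
                             schedule='continuous_train')
train_runner.run_experiment()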
Example #5
    def testCheckpointExperiment(self, mock_logger_constructor,
                                 mock_checkpointer_constructor):
        checkpoint_dir = os.path.join(self._test_subdir, 'checkpoints')
        test_dict = {'test': 1}
        iteration = 1729

        def bundle_and_checkpoint(x, y):
            self.assertEqual(checkpoint_dir, x)
            self.assertEqual(iteration, y)
            return test_dict

        self._agent.bundle_and_checkpoint.side_effect = bundle_and_checkpoint
        experiment_checkpointer = mock.Mock()
        mock_checkpointer_constructor.return_value = experiment_checkpointer
        logs_data = {'one': 1, 'two': 2}
        mock_logger = MockLogger(run_asserts=False, data=logs_data)
        mock_logger_constructor.return_value = mock_logger
        runner = run_experiment.Runner(self._test_subdir,
                                       self._create_agent_fn, mock.Mock)
        runner._checkpoint_experiment(iteration)
        self.assertEqual(1, experiment_checkpointer.save_checkpoint.call_count)
        mock_args, _ = experiment_checkpointer.save_checkpoint.call_args
        self.assertEqual(iteration, mock_args[0])
        test_dict['logs'] = logs_data
        test_dict['current_iteration'] = iteration
        self.assertDictEqual(test_dict, mock_args[1])
 def testRunOneEpisodeWithLowMaxSteps(self):
   max_steps_per_episode = 2
   environment = MockEnvironment()
   runner = run_experiment.Runner(
       self._test_subdir, self._create_agent_fn, lambda: environment,
       max_steps_per_episode=max_steps_per_episode)
   step_number, total_reward = runner._run_one_episode()
   self.assertEqual(self._agent.step.call_count, max_steps_per_episode - 1)
   self.assertEqual(self._agent.end_episode.call_count, 1)
   self.assertEqual(max_steps_per_episode, step_number)
   self.assertEqual(-1, total_reward)
 def testRunOneEpisode(self):
   max_steps_per_episode = 11
   environment = MockEnvironment()
   runner = run_experiment.Runner(
       self._test_subdir, self._create_agent_fn, lambda: environment,
       max_steps_per_episode=max_steps_per_episode)
   step_number, total_reward = runner._run_one_episode()
   self.assertEqual(self._agent.step.call_count, environment.max_steps - 1)
   self.assertEqual(self._agent.end_episode.call_count, 1)
   self.assertEqual(environment.max_steps, step_number)
   # Expected reward will be \sum_{i=0}^{9} (-1)**i * i = -5
   self.assertEqual(-5, total_reward)
 def testRunExperimentWithInconsistentRange(self, mock_logger_constructor,
                                            mock_checkpointer_constructor):
   experiment_logger = MockLogger()
   mock_logger_constructor.return_value = experiment_logger
   experiment_checkpointer = mock.Mock()
   mock_checkpointer_constructor.return_value = experiment_checkpointer
   runner = run_experiment.Runner(
       self._test_subdir, self._create_agent_fn, mock.Mock,
       num_iterations=0)
   runner.run_experiment()
   self.assertEqual(0, experiment_checkpointer.save_checkpoint.call_count)
   self.assertEqual(0, experiment_logger._calls_to_set)
   self.assertEqual(0, experiment_logger._calls_to_log)
Example #9
 def testDefaultGinDqn(self):
   """Test DQNAgent configuration using the default gin config."""
   tf.logging.info('####### Training the DQN agent #####')
   tf.logging.info('####### DQN base_dir: {}'.format(self._base_dir))
   gin_files = ['dopamine/agents/dqn/configs/dqn.gin']
   gin_bindings = [
       'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
       "create_agent.agent_name = 'dqn'"
   ]
   run_experiment.load_gin_configs(gin_files, gin_bindings)
   runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                  atari_lib.create_atari_environment)
   self.assertIsInstance(runner._agent.optimizer, tf.train.RMSPropOptimizer)
   self.assertNear(0.00025, runner._agent.optimizer._learning_rate, 0.0001)
   shutil.rmtree(self._base_dir)
Example #10
 def testDefaultGinRainbow(self):
     """Test RainbowAgent default configuration using default gin."""
     tf.logging.info('####### Training the RAINBOW agent #####')
     tf.logging.info('####### RAINBOW base_dir: {}'.format(self._base_dir))
     gin_files = ['dopamine/agents/rainbow/configs/rainbow.gin']
     gin_bindings = [
         'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
         "create_agent.agent_name = 'rainbow'"
     ]
     run_experiment.load_gin_configs(gin_files, gin_bindings)
     runner = run_experiment.Runner(self._base_dir,
                                    run_experiment.create_agent,
                                    atari_lib.create_atari_environment)
     self.assertIsInstance(runner._agent.optimizer, tf.train.AdamOptimizer)
     self.assertNear(0.0000625, runner._agent.optimizer._lr, 0.0001)
     shutil.rmtree(self._base_dir)
 def testLogExperiment(self, mock_logger_constructor):
   log_every_n = 2
   logging_file_prefix = 'prefix'
   statistics = 'statistics'
   experiment_logger = MockLogger(test_cls=self)
   mock_logger_constructor.return_value = experiment_logger
   runner = run_experiment.Runner(
       self._test_subdir, self._create_agent_fn, mock.Mock,
       logging_file_prefix=logging_file_prefix,
       log_every_n=log_every_n)
   num_iterations = 10
   for i in range(num_iterations):
     runner._log_experiment(i, statistics)
   self.assertEqual(num_iterations, experiment_logger._calls_to_set)
   self.assertEqual((num_iterations / log_every_n),
                    experiment_logger._calls_to_log)
Example #12
 def testRewardClipping(self, reward_clipping, reward, expected_reward):
     environment = tf.test.mock.Mock()
     environment.step.return_value = (0, reward, True, {})
     mock_agent = tf.test.mock.Mock()
     agent_fn = tf.test.mock.MagicMock(return_value=mock_agent)
     runner = run_experiment.Runner(self.get_temp_dir(),
                                    agent_fn,
                                    lambda: environment,
                                    log_every_n=1,
                                    num_iterations=1,
                                    training_steps=1,
                                    evaluation_steps=0,
                                    reward_clipping=reward_clipping)
     runner._checkpoint_experiment = tf.test.mock.Mock()
     runner._log_experiment = tf.test.mock.Mock()
     runner.run_experiment()
     mock_agent.end_episode.assert_called_once_with(expected_reward)
Example #13
 def testDefaultGinImplicitQuantileIcml(self):
   """Test default ImplicitQuantile configuration using ICML gin."""
   tf.logging.info('###### Training the Implicit Quantile agent #####')
   self._base_dir = os.path.join(
       '/tmp/dopamine_tests',
       datetime.datetime.utcnow().strftime('run_%Y_%m_%d_%H_%M_%S'))
   tf.logging.info('###### IQN base dir: {}'.format(self._base_dir))
   gin_files = ['dopamine/agents/'
                'implicit_quantile/configs/implicit_quantile_icml.gin']
   gin_bindings = [
       'Runner.num_iterations=0',
   ]
   run_experiment.load_gin_configs(gin_files, gin_bindings)
   runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                  atari_lib.create_atari_environment)
   self.assertEqual(1000000, runner._agent._replay.memory._replay_capacity)
   shutil.rmtree(self._base_dir)
Example #14
  def testOverrideGinDqn(self):
    """Test DQNAgent configuration overridden with AdamOptimizer."""
    tf.logging.info('####### Training the DQN agent #####')
    tf.logging.info('####### DQN base_dir: {}'.format(self._base_dir))
    gin_files = ['dopamine/agents/dqn/configs/dqn.gin']
    gin_bindings = [
        'DQNAgent.optimizer = @tf.train.AdamOptimizer()',
        'tf.train.AdamOptimizer.learning_rate = 100',
        'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
        "create_agent.agent_name = 'dqn'"
    ]

    run_experiment.load_gin_configs(gin_files, gin_bindings)
    runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                   atari_lib.create_atari_environment)
    self.assertIsInstance(runner._agent.optimizer, tf.train.AdamOptimizer)
    self.assertEqual(100, runner._agent.optimizer._lr)
    shutil.rmtree(self._base_dir)
Example #15
 def testOverrideGinRainbow(self):
   """Test RainbowAgent configuration overridden with RMSPropOptimizer."""
   tf.logging.info('####### Training the RAINBOW agent #####')
   tf.logging.info('####### RAINBOW base_dir: {}'.format(self._base_dir))
   gin_files = [
       'dopamine/agents/rainbow/configs/rainbow.gin',
   ]
   gin_bindings = [
       'RainbowAgent.optimizer = @tf.train.RMSPropOptimizer()',
       'tf.train.RMSPropOptimizer.learning_rate = 100',
       'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
       "create_agent.agent_name = 'rainbow'"
   ]
   run_experiment.load_gin_configs(gin_files, gin_bindings)
   runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                  atari_lib.create_atari_environment)
   self.assertIsInstance(runner._agent.optimizer, tf.train.RMSPropOptimizer)
   self.assertEqual(100, runner._agent.optimizer._learning_rate)
   shutil.rmtree(self._base_dir)
Example #16
    def testRunExperiment(self, mock_logger_constructor,
                          mock_checkpointer_constructor, mock_get_latest):
        log_every_n = 1
        environment = MockEnvironment()
        experiment_logger = MockLogger(run_asserts=False)
        mock_logger_constructor.return_value = experiment_logger
        experiment_checkpointer = mock.Mock()
        start_iteration = 1729
        mock_get_latest.return_value = start_iteration

        def load_checkpoint(_):
            return {
                'logs': 'log_data',
                'current_iteration': start_iteration - 1
            }

        experiment_checkpointer.load_checkpoint.side_effect = load_checkpoint
        mock_checkpointer_constructor.return_value = experiment_checkpointer

        def bundle_and_checkpoint(x, y):
            del x, y  # Unused.
            return {'test': 1}

        self._agent.bundle_and_checkpoint.side_effect = bundle_and_checkpoint
        num_iterations = 10
        self._agent.unbundle.return_value = True
        end_iteration = start_iteration + num_iterations
        runner = run_experiment.Runner(self._test_subdir,
                                       self._create_agent_fn,
                                       lambda: environment,
                                       log_every_n=log_every_n,
                                       num_iterations=end_iteration,
                                       training_steps=1,
                                       evaluation_steps=1)
        self.assertEqual(start_iteration, runner._start_iteration)
        runner.run_experiment()
        self.assertEqual(num_iterations,
                         experiment_checkpointer.save_checkpoint.call_count)
        self.assertEqual(num_iterations, experiment_logger._calls_to_set)
        self.assertEqual(num_iterations, experiment_logger._calls_to_log)
        glob_string = '{}/events.out.tfevents.*'.format(self._test_subdir)
        self.assertGreater(len(tf.gfile.Glob(glob_string)), 0)
 def create_runner(self, env_fn, hparams, target_iterations,
                   training_steps_per_iteration):
     # pylint: disable=unbalanced-tuple-unpacking
     agent_params, optimizer_params, \
     runner_params, replay_buffer_params = _parse_hparams(hparams)
     # pylint: enable=unbalanced-tuple-unpacking
     optimizer = _get_optimizer(optimizer_params)
     agent_params["optimizer"] = optimizer
     agent_params.update(replay_buffer_params)
     create_agent_fn = get_create_agent(agent_params)
     runner = run_experiment.Runner(
         base_dir=self.agent_model_dir,
         create_agent_fn=create_agent_fn,
         create_environment_fn=get_create_env_fun(
             env_fn, time_limit=hparams.time_limit),
         evaluation_steps=0,
         num_iterations=target_iterations,
         training_steps=training_steps_per_iteration,
         **runner_params)
     return runner
Example #18
def main(unused_argv):
    """Main method.

  Args:
    unused_argv: Arguments (unused).
  """
    logging.set_verbosity(logging.INFO)
    tf.compat.v1.disable_v2_behavior()
    base_dir = FLAGS.base_dir
    gin_files, gin_bindings = FLAGS.gin_files, FLAGS.gin_bindings
    run_experiment.load_gin_configs(gin_files, gin_bindings)
    # Set the Jax agent seed using the run number
    create_agent_fn = functools.partial(create_agent, seed=FLAGS.run_number)
    if FLAGS.max_episode_eval:
        runner_fn = eval_run_experiment.MaxEpisodeEvalRunner
        logging.info('Using MaxEpisodeEvalRunner for evaluation.')
        runner = runner_fn(base_dir, create_agent_fn)
    else:
        runner = run_experiment.Runner(base_dir, create_agent_fn)
    runner.run_experiment()
 def testRunOneIteration(self):
   environment_steps = 2
   environment = MockEnvironment(max_steps=environment_steps)
   training_steps = 20
   evaluation_steps = 10
   runner = run_experiment.Runner(
       self._test_subdir, self._create_agent_fn, lambda: environment,
       training_steps=training_steps,
       evaluation_steps=evaluation_steps)
   dictionary = runner._run_one_iteration(1)
   train_calls = int(training_steps / environment_steps)
   eval_calls = int(evaluation_steps / environment_steps)
   expected_dictionary = {
       'train_episode_lengths': [2 for _ in range(train_calls)],
       'train_episode_returns': [-1 for _ in range(train_calls)],
       'train_average_return': [-1],
       'eval_episode_lengths': [2 for _ in range(eval_calls)],
       'eval_episode_returns': [-1 for _ in range(eval_calls)],
       'eval_average_return': [-1]
   }
   self.assertDictEqual(expected_dictionary, dictionary)
 def testInitializeCheckpointingWhenCheckpointUnbundleSucceeds(
     self, mock_get_latest):
   latest_checkpoint = 7
   mock_get_latest.return_value = latest_checkpoint
   logs_data = {'a': 1, 'b': 2}
   current_iteration = 1729
   checkpoint_data = {'current_iteration': current_iteration,
                      'logs': logs_data}
   checkpoint_dir = os.path.join(self._test_subdir, 'checkpoints')
   checkpoint = checkpointer.Checkpointer(checkpoint_dir, 'ckpt')
   checkpoint.save_checkpoint(latest_checkpoint, checkpoint_data)
   mock_agent = mock.Mock()
   mock_agent.unbundle.return_value = True
   runner = run_experiment.Runner(self._test_subdir,
                                  lambda x, y, summary_writer: mock_agent,
                                  mock.Mock)
   expected_iteration = current_iteration + 1
   self.assertEqual(expected_iteration, runner._start_iteration)
   self.assertDictEqual(logs_data, runner._logger.data)
   mock_agent.unbundle.assert_called_once_with(
       checkpoint_dir, latest_checkpoint, checkpoint_data)
Example #21
def dopamine_train(base_dir,
                   hidden_layer_size,
                   gamma,
                   learning_rate,
                   num_train_steps,
                   network='chain'):
    """Train an agent using dopamine."""
    runner = run_experiment.Runner(base_dir,
                                   functools.partial(
                                       _create_agent,
                                       hidden_layer_size=hidden_layer_size,
                                       gamma=gamma,
                                       learning_rate=learning_rate),
                                   functools.partial(_create_environment,
                                                     network=network),
                                   num_iterations=num_train_steps,
                                   training_steps=500,
                                   evaluation_steps=100,
                                   max_steps_per_episode=20)
    runner.run_experiment()
    return runner
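A hedged example of invoking the helper above. The base directory and hyperparameter values are placeholders, and network='chain' is simply the default shown in the signature:

# Train a small agent for 10 iterations of 500 training / 100 eval steps each.
trained_runner = dopamine_train(base_dir='/tmp/chain_experiment',
                                hidden_layer_size=64,
                                gamma=0.99,
                                learning_rate=1e-3,
                                num_train_steps=10,
                                network='chain')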
Example #22
 def testOverrideGinImplicitQuantile(self):
     """Test ImplicitQuantile configuration overriding using IQN gin."""
     logging.info('###### Training the Implicit Quantile agent #####')
     self._base_dir = os.path.join(
         '/tmp/dopamine_tests',
         datetime.datetime.utcnow().strftime('run_%Y_%m_%d_%H_%M_%S'))
     logging.info('###### IQN base dir: %s', self._base_dir)
     gin_files = [
         'dopamine/agents/implicit_quantile/configs/'
         'implicit_quantile.gin'
     ]
     gin_bindings = [
         'Runner.num_iterations=0',
         'WrappedPrioritizedReplayBuffer.replay_capacity = 1000',
     ]
     run_experiment.load_gin_configs(gin_files, gin_bindings)
     runner = run_experiment.Runner(self._base_dir,
                                    run_experiment.create_agent,
                                    atari_lib.create_atari_environment)
     self.assertEqual(1000, runner._agent._replay.memory._replay_capacity)
     shutil.rmtree(self._base_dir)
Example #23
 def testInitializeCheckpointingWhenCheckpointUnbundleFails(
         self, mock_logger_constructor, mock_checkpointer_constructor,
         mock_get_latest):
     mock_checkpointer = _create_mock_checkpointer()
     mock_checkpointer_constructor.return_value = mock_checkpointer
     latest_checkpoint = 7
     mock_get_latest.return_value = latest_checkpoint
     agent = mock.Mock()
     agent.unbundle.return_value = False
     mock_logger = mock.Mock()
     mock_logger_constructor.return_value = mock_logger
     runner = run_experiment.Runner(self._test_subdir,
                                    lambda x, y, summary_writer: agent,
                                    mock.Mock)
     self.assertEqual(0, runner._start_iteration)
     self.assertEqual(1, mock_checkpointer.load_checkpoint.call_count)
     self.assertEqual(1, agent.unbundle.call_count)
     mock_args, _ = agent.unbundle.call_args
     self.assertEqual('{}/checkpoints'.format(self._test_subdir),
                      mock_args[0])
     self.assertEqual(latest_checkpoint, mock_args[1])
     expected_dictionary = {'current_iteration': 1729, 'logs': 'logs'}
     self.assertDictEqual(expected_dictionary, mock_args[2])
Example #24
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  runner = run_experiment.Runner(FLAGS.base_dir, create_agent)
  runner.run_experiment()
def make_runner(gin_files):
    run_experiment.load_gin_configs(gin_files, None)
    runner = run_experiment.Runner(base_dir=base_dir,
                                   create_agent_fn=create_agent_fn,
                                   create_environment_fn=create_env_fn)
    return runner
Example #26
ac_config = """
create_agent.agent_name = 'ac'

run_experiment.Runner.create_environment_fn = @gym_lib.create_gym_environment
run_experiment.Runner.num_iterations = 500
run_experiment.Runner.training_steps = 1000
run_experiment.Runner.evaluation_steps = 1000
run_experiment.Runner.max_steps_per_episode = 200

circular_replay_buffer.WrappedReplayBuffer.replay_capacity = 50000
circular_replay_buffer.WrappedReplayBuffer.batch_size = 128

ACAgent.actor_network = @CartpoleActorNetwork
ACAgent.critic_network = @CartpoleCriticNetwork
ACAgent.observation_shape = %gym_lib.CARTPOLE_OBSERVATION_SHAPE
ACAgent.observation_dtype = %gym_lib.CARTPOLE_OBSERVATION_DTYPE
ACAgent.stack_size = %gym_lib.CARTPOLE_STACK_SIZE


tf.train.AdamOptimizer.learning_rate = 0.00001
tf.train.AdamOptimizer.epsilon = 0.00000390625
"""

tf.logging.set_verbosity(tf.logging.INFO)

gin.parse_config(ac_config, skip_unknown=False)

ac_runner = run_experiment.Runner(LOG_PATH, create_ac_agent)

print('Will train ac agent, please be patient, may be a while...')
ac_runner.run_experiment()
print('Done training!')
Example #27
def run(bsuite_id: str) -> str:
    """Runs Dopamine DQN on a given bsuite environment, logging to CSV."""

    raw_env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )

    class Network(tf.keras.Model):
        """Build deep network compatible with dopamine/discrete_domains/gym_lib."""
        def __init__(self, num_actions: int, name='Network'):
            super(Network, self).__init__(name=name)
            self.forward_fn = tf.keras.Sequential(
                [tf.keras.layers.Flatten()] + [
                    tf.keras.layers.Dense(FLAGS.num_units,
                                          activation=tf.keras.activations.relu)
                    for _ in range(FLAGS.num_hidden_layers)
                ] + [tf.keras.layers.Dense(num_actions, activation=None)])

        def call(self, state):
            """Creates the output tensor/op given the state tensor as input."""
            x = tf.cast(state, tf.float32)
            x = self.forward_fn(x)
            return atari_lib.DQNNetworkType(x)

    def create_agent(sess: tf.Session,
                     environment: gym.Env,
                     summary_writer=None):
        """Factory method for agent initialization in Dopmamine."""
        del summary_writer
        return dqn_agent.DQNAgent(
            sess=sess,
            num_actions=environment.action_space.n,
            observation_shape=OBSERVATION_SHAPE,
            observation_dtype=tf.float32,
            stack_size=1,
            network=Network,
            gamma=FLAGS.agent_discount,
            update_horizon=1,
            min_replay_history=FLAGS.min_replay_size,
            update_period=FLAGS.sgd_period,
            target_update_period=FLAGS.target_update_period,
            epsilon_decay_period=FLAGS.epsilon_decay_period,
            epsilon_train=FLAGS.epsilon,
            optimizer=tf.train.AdamOptimizer(FLAGS.learning_rate),
        )

    def create_environment() -> gym.Env:
        """Factory method for environment initialization in Dopmamine."""
        env = wrappers.ImageObservation(raw_env, OBSERVATION_SHAPE)
        if FLAGS.verbose:
            env = terminal_logging.wrap_environment(env, log_every=True)  # pytype: disable=wrong-arg-types
        env = gym_wrapper.GymFromDMEnv(env)
        env.game_over = False  # Dopamine looks for this
        return env

    runner = run_experiment.Runner(
        base_dir=FLAGS.base_dir,
        create_agent_fn=create_agent,
        create_environment_fn=create_environment,
    )

    num_episodes = FLAGS.num_episodes or getattr(raw_env,
                                                 'bsuite_num_episodes')
    for _ in range(num_episodes):
        runner._run_one_episode()  # pylint: disable=protected-access

    return bsuite_id
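For reference, one way the run helper above is typically driven is by looping over the full bsuite sweep. This is a hedged sketch; the flags used inside run are assumed to be configured elsewhere.

from bsuite import sweep

# Run every environment in the bsuite sweep, logging each to CSV.
for bsuite_id in sweep.SWEEP:
    run(bsuite_id)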
Example #28
env_path = '../env/AnimalAI'
worker_id = random.randint(1, 100)
arena_config_in = ArenaConfig('configs/1-Food.yaml')
base_dir = 'models/dopamine'
gin_files = ['configs/rainbow.gin']


def create_env_fn():
    env = AnimalAIEnv(environment_filename=env_path,
                      worker_id=worker_id,
                      n_arenas=1,
                      arenas_configurations=arena_config_in,
                      docker_training=False,
                      retro=True)
    print("all good create_env_fn")
    return env


def create_agent_fn(sess, env, summary_writer):
    return rainbow_agent.RainbowAgent(sess=sess,
                                      num_actions=env.action_space.n,
                                      summary_writer=summary_writer)


run_experiment.load_gin_configs(gin_files, None)
runner = run_experiment.Runner(base_dir=base_dir,
                               create_agent_fn=create_agent_fn,
                               create_environment_fn=create_env_fn)
# runner.run_experiment()
Example #29
def main(unused_argv):
    config.update('jax_disable_jit', FLAGS.disable_jit)
    logging.set_verbosity(logging.INFO)
    run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    runner = run_experiment.Runner(FLAGS.base_dir, create_agent)
    runner.run_experiment()
Example #30
 def testInitializeCheckpointingWithNoCheckpointFile(self, mock_get_latest):
     mock_get_latest.return_value = -1
     base_dir = '/does/not/exist'
     with self.assertRaisesRegexp(tf.errors.PermissionDeniedError,
                                  '.*/does.*'):
         run_experiment.Runner(base_dir, self._create_agent_fn, mock.Mock)