def main(unused_argv):
    """Run the Dopamine CartPole experiment and record wall-clock timings.

    Training runtime is appended to ``runtime.csv`` and, when the runner is
    configured with inference steps, inference runtime is appended to
    ``inference.csv``. The experiment name is derived from the first gin file.

    Args:
      unused_argv: Arguments (unused).
    """
    # init logging
    tf.logging.set_verbosity(tf.logging.INFO)
    run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)

    # Derive the experiment name from the first gin file's basename,
    # e.g. 'configs/cartpole.gin' -> 'cartpole'.
    ginfile = str(FLAGS.gin_files[0])
    experiment_name = ginfile[ginfile.rfind('/') + 1: ginfile.rfind('.gin')]
    log_dir = os.path.join(FLAGS.base_dir, experiment_name)
    runtime_dir = os.path.join(FLAGS.base_dir, 'runtime')
    runtime_file = os.path.join(runtime_dir, 'runtime.csv')
    inference_file = os.path.join(runtime_dir, 'inference.csv')
    # Make sure the output directory exists before appending to the CSVs;
    # open(..., 'a+') does not create intermediate directories.
    os.makedirs(runtime_dir, exist_ok=True)

    runner = checkpoint_runner.create_runner(log_dir)
    start_time = time.time()
    runner.run_experiment()
    end_time = time.time()

    # Context manager guarantees the file is closed even if the write fails.
    with open(runtime_file, 'a+') as f:
        f.write(experiment_name + ', ' + str(end_time - start_time) + '\n')

    if runner.inference_steps:
        print("--- STARTING DOPAMINE CARTPOLE INFERENCE EXPERIMENT ---\n")
        start_time = time.time()
        runner.run_inference_test()
        end_time = time.time()
        with open(inference_file, 'a+') as f:
            f.write(experiment_name + ', ' +
                    str(end_time - start_time) + '\n')
        print("--- DOPAMINE CARTPOLE INFERENCE EXPERIMENT COMPLETED ---\n")
Ejemplo n.º 2
0
def main(unused_argv):
    """Train a fixed-replay agent from previously logged replay data.

    Args:
      unused_argv: Arguments (unused).
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    base_run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    # Replay buffers are expected under <replay_dir>/replay_logs.
    data_dir = os.path.join(FLAGS.replay_dir, 'replay_logs')
    agent_fn = functools.partial(create_agent, replay_data_dir=data_dir)
    run_experiment.FixedReplayRunner(FLAGS.base_dir, agent_fn).run_experiment()
Ejemplo n.º 3
0
 def __init__(self, atari_roms_source, atari_roms_path, gin_files,
              gin_bindings, random_seed, no_op, best_sampling):
     """Set up an Atari environment wrapper backed by a ConqurRunner.

     Args:
       atari_roms_source: Source location of the Atari ROMs to copy.
       atari_roms_path: Destination directory for the copied ROMs.
       gin_files: List of gin configuration files to load.
       gin_bindings: List of gin parameter bindings.
       random_seed: Seed for the NumPy random state.
       no_op: Passed through to ConqurRunner.
       best_sampling: Passed through to ConqurRunner.
     """
     # ROMs and gin config must be in place before the runner is built.
     atari_lib.copy_roms(atari_roms_source, destination_dir=atari_roms_path)
     run_experiment.load_gin_configs(gin_files, gin_bindings)
     self.random_state = np.random.RandomState(random_seed)
     self.runner = ConqurRunner(self.random_state, no_op, best_sampling)
     # Expose the runner's action-space size on this wrapper.
     self.num_actions = self.runner.num_actions
     super(Atari, self).__init__()
Ejemplo n.º 4
0
 def testLoadGinConfigs(self, mock_parse_config_files_and_bindings):
     """load_gin_configs should forward files and bindings to gin's parser."""
     files = ['file1', 'file2', 'file3']
     bindings = ['binding1', 'binding2']
     run_experiment.load_gin_configs(files, bindings)
     # Exactly one parse call: files positional, bindings by keyword,
     # and unknown bindings must not be skipped.
     self.assertEqual(1, mock_parse_config_files_and_bindings.call_count)
     args, kwargs = mock_parse_config_files_and_bindings.call_args
     self.assertEqual(files, args[0])
     self.assertEqual(bindings, kwargs['bindings'])
     self.assertFalse(kwargs['skip_unknown'])
Ejemplo n.º 5
0
def run(unused_argv):
  """Run an experiment whose replay buffer is logged to disk.

  Args:
    unused_argv: Arguments (unused).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  run_experiment.load_gin_configs([FLAGS.gin_file], [])
  # Replay buffer data is written under <base_dir>/replay_logs.
  log_dir = os.path.join(FLAGS.base_dir, 'replay_logs')
  tf.logging.info('Saving replay buffer data to {}'.format(log_dir))
  agent_fn = functools.partial(create_agent, replay_log_dir=log_dir)
  LoggedRunner(FLAGS.base_dir, agent_fn).run_experiment()
Ejemplo n.º 6
0
def main(unused_argv):
  """Load gin configs and run the default Dopamine experiment.

  Args:
    unused_argv: Arguments (unused).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  run_experiment.create_runner(FLAGS.base_dir).run_experiment()
Ejemplo n.º 7
0
def main(unused_argv):
    """Run a bubble-runner experiment for the configured level.

    Args:
      unused_argv: Arguments (unused).
    """
    # Relative import of the runner module used by this entry point.
    from . import bubble_runner
    tf.logging.set_verbosity(tf.logging.INFO)
    run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    runner = bubble_runner.create_runner(FLAGS.base_dir, level=FLAGS.level)
    runner.run_experiment()
def main(unused_argv):
    """Wrapper around a gin-configurable, checkpoint-aware experiment.

    Args:
      unused_argv: Arguments (unused).
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    # schedule='save_best' selects the checkpointing behaviour of the runner.
    runner = run_experiment_from_checkpoint.create_runner_checkpoint(
        FLAGS.base_dir,
        run_experiment.create_agent,
        schedule='save_best')
    runner.run_experiment()
Ejemplo n.º 9
0
def main(unused_argv):
    """Run a Dopamine experiment with TF2 behavior disabled.

    Args:
      unused_argv: Arguments (unused).
    """
    logging.set_verbosity(logging.INFO)
    # Disable TF2 behavior before any graph construction happens.
    tf.compat.v1.disable_v2_behavior()

    run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    run_experiment.create_runner(FLAGS.base_dir).run_experiment()
Ejemplo n.º 10
0
def main(unused_argv):
    """Run an experiment in a timestamp-suffixed base directory.

    Args:
      unused_argv: Arguments (unused).
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    print("levin: test ok")

    run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    # Suffix the base directory with the current UNIX time so repeated
    # runs write to distinct directories.
    FLAGS.base_dir = '{}-{}'.format(FLAGS.base_dir, int(time.time()))
    run_experiment.create_runner(FLAGS.base_dir).run_experiment()
Ejemplo n.º 11
0
def main(unused_argv):
    """Train a metric agent with Dopamine's TrainRunner.

    Args:
      unused_argv: Arguments (unused).
    """
    logging.set_verbosity(logging.INFO)
    run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    runner = run_experiment.TrainRunner(FLAGS.base_dir, create_metric_agent)
    runner.run_experiment()
Ejemplo n.º 12
0
def launch_experiment(create_runner_fn, create_agent_fn):
  """Launches the experiment.

  Args:
    create_runner_fn: Function taking a base directory, an agent-creation
      function, and a schedule, returning a `Runner`-like object.
    create_agent_fn: Function taking a TensorFlow session and a Gym Atari
      2600 environment, returning an agent.
  """
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  runner = create_runner_fn(
      FLAGS.base_dir, create_agent_fn, schedule=FLAGS.schedule)
  runner.run_experiment()
Ejemplo n.º 13
0
def main(unused_argv):
  """Visualize a bisimulation Rainbow agent instead of training it.

  Loads the bisimulation rainbow gin config and calls the runner's
  `visualize` loop.

  Args:
    unused_argv: Arguments (unused).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  gin_file = 'bisimulation_aaai2020/dopamine/configs/rainbow.gin'
  dopamine_run_experiment.load_gin_configs([gin_file], FLAGS.gin_bindings)
  gin_binding = 'BisimulationRainbowAgent.evaluate_metric_only=True'
  # NOTE(review): this binding is appended AFTER load_gin_configs has already
  # parsed FLAGS.gin_bindings, so it only takes effect if something downstream
  # (e.g. create_runner) re-parses the bindings — confirm this is intended.
  FLAGS.gin_bindings.append(gin_binding)
  runner = run_experiment.create_runner(FLAGS.base_dir,
                                        FLAGS.metric_checkpoint)
  runner.visualize(num_global_steps=FLAGS.num_global_steps)
Ejemplo n.º 14
0
def test_complete_experiment():
    """Smoke test running small CartPole and ParkQOpt experiments.

    Fails the test (rather than erroring) if any exception is raised while
    configuring or running the experiment.
    """
    try:
        # Keep test output quiet.
        tf.logging.set_verbosity(tf.logging.ERROR)

        # configure experiment
        run_experiment.load_gin_configs(PARAMS, [])
        # create the agent and run experiment
        runner = checkpoint_runner.create_runner(BASE_DIR)
        runner.run_experiment()
    except Exception as e:  # broad by design: any failure fails the smoke test
        # Include the original exception so the failure is diagnosable;
        # the bare `pytest.fail(...)` previously discarded it.
        pytest.fail('Running experiments in Dopamine failed: {!r}'.format(e))
def main(unused_argv):
    """Wrapper around a gin-configurable experiment loaded from a checkpoint.

    Args:
      unused_argv: Arguments (unused).
    """
    del unused_argv
    # Point LoadFromRunner at the original experiment directory via gin.
    extra_binding = 'LoadFromRunner.original_base_dir="{}"'.format(
        FLAGS.original_base_dir)
    bindings = FLAGS.gin_bindings + [extra_binding]

    tf.logging.set_verbosity(tf.logging.INFO)
    run_experiment.load_gin_configs(FLAGS.gin_files, bindings)
    runner = run_experiment_from_checkpoint.create_runner_checkpoint(
        FLAGS.base_dir,
        run_experiment.create_agent,
        schedule='load_from_best')
    runner.run_experiment()
Ejemplo n.º 16
0
 def testDefaultGinDqn(self):
     """Test DQNAgent configuration using the default gin config."""
     tf.logging.info('####### Training the DQN agent #####')
     tf.logging.info('####### DQN base_dir: {}'.format(self._base_dir))
     config_files = ['dopamine/agents/dqn/configs/dqn.gin']
     bindings = [
         'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
         "create_agent.agent_name = 'dqn'",
     ]
     run_experiment.load_gin_configs(config_files, bindings)
     runner = run_experiment.Runner(
         self._base_dir, run_experiment.create_agent,
         atari_lib.create_atari_environment)
     # Stock DQN config: RMSProp optimizer with lr close to 2.5e-4.
     self.assertIsInstance(runner._agent.optimizer, tf.train.RMSPropOptimizer)
     self.assertNear(0.00025, runner._agent.optimizer._learning_rate, 0.0001)
     shutil.rmtree(self._base_dir)
Ejemplo n.º 17
0
def main(unused_argv):
    """Run a tandem-agents experiment with TF2 behavior disabled.

    Args:
      unused_argv: Arguments (unused).
    """
    logging.set_verbosity(logging.INFO)
    # Disable TF2 behavior before any graph construction happens.
    tf.compat.v1.disable_v2_behavior()

    base_run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    runner = run_experiment.TandemRunner(
        FLAGS.base_dir, run_experiment.create_tandem_agents_and_checkpoints)
    runner.run_experiment()
Ejemplo n.º 18
0
 def testDefaultGinRainbow(self):
     """Test RainbowAgent default configuration using default gin."""
     tf.logging.info('####### Training the RAINBOW agent #####')
     tf.logging.info('####### RAINBOW base_dir: {}'.format(self._base_dir))
     config_files = ['dopamine/agents/rainbow/configs/rainbow.gin']
     bindings = [
         'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
         "create_agent.agent_name = 'rainbow'",
     ]
     run_experiment.load_gin_configs(config_files, bindings)
     runner = run_experiment.Runner(
         self._base_dir, run_experiment.create_agent,
         atari_lib.create_atari_environment)
     # Stock Rainbow config: Adam optimizer with lr close to 6.25e-5.
     self.assertIsInstance(runner._agent.optimizer, tf.train.AdamOptimizer)
     self.assertNear(0.0000625, runner._agent.optimizer._lr, 0.0001)
     shutil.rmtree(self._base_dir)
Ejemplo n.º 19
0
 def testDefaultDQNConfig(self):
     """Verify the default DQN configuration."""
     run_experiment.load_gin_configs(
         ['dopamine/agents/dqn/configs/dqn.gin'], [])
     agent = run_experiment.create_agent(
         self.test_session(),
         atari_lib.create_atari_environment(game_name='Pong'))
     # Expected hyperparameters from the stock dqn.gin config.
     checks = [
         (agent.gamma, 0.99),
         (agent.update_horizon, 1),
         (agent.min_replay_history, 20000),
         (agent.update_period, 4),
         (agent.target_update_period, 8000),
         (agent.epsilon_train, 0.01),
         (agent.epsilon_eval, 0.001),
         (agent.epsilon_decay_period, 250000),
         (agent._replay.memory._replay_capacity, 1000000),
         (agent._replay.memory._batch_size, 32),
     ]
     for actual, want in checks:
         self.assertEqual(actual, want)
Ejemplo n.º 20
0
 def testDefaultGinImplicitQuantileIcml(self):
     """Test default ImplicitQuantile configuration using ICML gin."""
     tf.logging.info('###### Training the Implicit Quantile agent #####')
     self._base_dir = os.path.join(
         '/tmp/dopamine_tests',
         datetime.datetime.utcnow().strftime('run_%Y_%m_%d_%H_%M_%S'))
     tf.logging.info('###### IQN base dir: {}'.format(self._base_dir))
     config_files = ['dopamine/agents/'
                     'implicit_quantile/configs/implicit_quantile_icml.gin']
     # Zero iterations: only construction-time configuration is inspected.
     bindings = ['Runner.num_iterations=0']
     run_experiment.load_gin_configs(config_files, bindings)
     runner = run_experiment.Runner(
         self._base_dir, run_experiment.create_agent,
         atari_lib.create_atari_environment)
     self.assertEqual(1000000, runner._agent._replay.memory._replay_capacity)
     shutil.rmtree(self._base_dir)
Ejemplo n.º 21
0
  def testOverrideGinDqn(self):
    """Test DQNAgent configuration overridden with AdamOptimizer."""
    tf.logging.info('####### Training the DQN agent #####')
    tf.logging.info('####### DQN base_dir: {}'.format(self._base_dir))
    config_files = ['dopamine/agents/dqn/configs/dqn.gin']
    bindings = [
        'DQNAgent.optimizer = @tf.train.AdamOptimizer()',
        'tf.train.AdamOptimizer.learning_rate = 100',
        'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
        "create_agent.agent_name = 'dqn'",
    ]
    run_experiment.load_gin_configs(config_files, bindings)
    runner = run_experiment.Runner(
        self._base_dir, run_experiment.create_agent,
        atari_lib.create_atari_environment)
    # The gin override should replace RMSProp with Adam at lr = 100.
    self.assertIsInstance(runner._agent.optimizer, tf.train.AdamOptimizer)
    self.assertEqual(100, runner._agent.optimizer._lr)
    shutil.rmtree(self._base_dir)
Ejemplo n.º 22
0
 def testOverrideRunnerParams(self):
     """Test DQNAgent configuration using the default gin config."""
     tf.logging.info('####### Training the DQN agent #####')
     tf.logging.info('####### DQN base_dir: {}'.format(self._base_dir))
     config_files = ['dopamine/agents/dqn/configs/dqn.gin']
     bindings = [
         'TrainRunner.base_dir = "{}"'.format(self._base_dir),
         'Runner.log_every_n = 1729',
         'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
         "create_agent.agent_name = 'dqn'",
     ]
     run_experiment.load_gin_configs(config_files, bindings)
     # base_dir comes from the gin binding rather than a constructor arg.
     runner = run_experiment.TrainRunner(
         create_agent_fn=run_experiment.create_agent,
         create_environment_fn=atari_lib.create_atari_environment)
     self.assertEqual(runner._base_dir, self._base_dir)
     self.assertEqual(runner._log_every_n, 1729)
     shutil.rmtree(self._base_dir)
Ejemplo n.º 23
0
 def testOverrideGinRainbow(self):
     """Test RainbowAgent configuration overridden with RMSPropOptimizer."""
     tf.logging.info('####### Training the RAINBOW agent #####')
     tf.logging.info('####### RAINBOW base_dir: {}'.format(self._base_dir))
     config_files = ['dopamine/agents/rainbow/configs/rainbow.gin']
     bindings = [
         'RainbowAgent.optimizer = @tf.train.RMSPropOptimizer()',
         'tf.train.RMSPropOptimizer.learning_rate = 100',
         'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
         "create_agent.agent_name = 'rainbow'",
     ]
     run_experiment.load_gin_configs(config_files, bindings)
     runner = run_experiment.Runner(
         self._base_dir, run_experiment.create_agent,
         atari_lib.create_atari_environment)
     # The gin override should replace Adam with RMSProp at lr = 100.
     self.assertIsInstance(runner._agent.optimizer, tf.train.RMSPropOptimizer)
     self.assertEqual(100, runner._agent.optimizer._learning_rate)
     shutil.rmtree(self._base_dir)
Ejemplo n.º 24
0
def main(unused_argv):
    """Run a Jax agent, optionally with max-episode evaluation.

    Args:
      unused_argv: Arguments (unused).
    """
    logging.set_verbosity(logging.INFO)
    tf.compat.v1.disable_v2_behavior()
    run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    # Set the Jax agent seed using the run number.
    agent_fn = functools.partial(create_agent, seed=FLAGS.run_number)
    if FLAGS.max_episode_eval:
        logging.info('Using MaxEpisodeEvalRunner for evaluation.')
        runner = eval_run_experiment.MaxEpisodeEvalRunner(
            FLAGS.base_dir, agent_fn)
    else:
        runner = run_experiment.Runner(FLAGS.base_dir, agent_fn)
    runner.run_experiment()
Ejemplo n.º 25
0
 def testOverrideGinImplicitQuantile(self):
     """Test ImplicitQuantile configuration overriding using IQN gin."""
     logging.info('###### Training the Implicit Quantile agent #####')
     self._base_dir = os.path.join(
         '/tmp/dopamine_tests',
         datetime.datetime.utcnow().strftime('run_%Y_%m_%d_%H_%M_%S'))
     logging.info('###### IQN base dir: %s', self._base_dir)
     config_files = ['dopamine/agents/implicit_quantile/configs/'
                     'implicit_quantile.gin']
     # Zero iterations: only construction-time configuration is inspected.
     bindings = [
         'Runner.num_iterations=0',
         'WrappedPrioritizedReplayBuffer.replay_capacity = 1000',
     ]
     run_experiment.load_gin_configs(config_files, bindings)
     runner = run_experiment.Runner(
         self._base_dir, run_experiment.create_agent,
         atari_lib.create_atari_environment)
     self.assertEqual(1000, runner._agent._replay.memory._replay_capacity)
     shutil.rmtree(self._base_dir)
Ejemplo n.º 26
0
def main(unused_argv):
    """Train a fixed-replay agent on preference-model reward logs, then pack it.

    Args:
      unused_argv: Arguments (unused).
    """
    path, split = osp.split(FLAGS.exp_dir)
    path, game = osp.split(path)
    # The game name is encoded in the experiment directory layout.
    gin.bind_parameter('atari_lib.create_atari_environment.game_name', game)
    # Guard clause: only the preference-rewards path is implemented.
    if not FLAGS.use_preference_rewards:
        raise NotImplementedError
    training_log_path = create_logs_for_training(FLAGS)
    agent_name = "_".join([
        FLAGS.agent_name, FLAGS.preference_model_type,
        FLAGS.reward_model_type
    ])
    FLAGS.replay_dir = training_log_path
    tf.logging.set_verbosity(tf.logging.INFO)
    base_run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    replay_data_dir = os.path.join(FLAGS.replay_dir, 'replay_logs')
    agent_fn = functools.partial(create_agent,
                                 replay_data_dir=replay_data_dir)
    runner = run_experiment.FixedReplayRunner(
        osp.join(FLAGS.exp_dir, agent_name), agent_fn)
    runner.run_experiment()
    pack_agents(FLAGS)
Ejemplo n.º 27
0
def main(unused_argv):
  """Run an offline fixed-replay experiment after XM preprocessing.

  Args:
    unused_argv: Arguments (unused).
  """
  logging.set_verbosity(logging.INFO)

  # XManager flags may be absent when running outside of XM.
  xm_xid = FLAGS.xm_xid if 'xm_xid' in FLAGS else None
  xm_wid = FLAGS.xm_wid if 'xm_wid' in FLAGS else None
  xm_parameters = FLAGS.xm_parameters if 'xm_parameters' in FLAGS else None
  base_dir, gin_files, gin_bindings = base_train.run_xm_preprocessing(
      xm_xid, xm_wid, xm_parameters, FLAGS.base_dir,
      FLAGS.custom_base_dir_from_hparams, FLAGS.gin_files, FLAGS.gin_bindings)
  agent_fn = functools.partial(
      create_offline_agent, agent_name=FLAGS.agent_name)
  base_run_experiment.load_gin_configs(gin_files, gin_bindings)
  run_experiment.FixedReplayRunner(base_dir, agent_fn).run_experiment()
Ejemplo n.º 28
0
 def testDefaultRainbowConfig(self):
     """Verify the default Rainbow configuration."""
     run_experiment.load_gin_configs(
         ['dopamine/agents/rainbow/configs/rainbow.gin'], [])
     agent = run_experiment.create_agent(
         self.test_session(),
         atari_lib.create_atari_environment(game_name='Pong'))
     # Distributional head: 51 atoms supported on [-10, 10].
     self.assertEqual(agent._num_atoms, 51)
     support = self.evaluate(agent._support)
     self.assertEqual(min(support), -10.)
     self.assertEqual(max(support), 10.)
     self.assertEqual(len(support), 51)
     # Expected hyperparameters from the stock rainbow.gin config.
     checks = [
         (agent.gamma, 0.99),
         (agent.update_horizon, 3),
         (agent.min_replay_history, 20000),
         (agent.update_period, 4),
         (agent.target_update_period, 8000),
         (agent.epsilon_train, 0.01),
         (agent.epsilon_eval, 0.001),
         (agent.epsilon_decay_period, 250000),
         (agent._replay.memory._replay_capacity, 1000000),
         (agent._replay.memory._batch_size, 32),
     ]
     for actual, want in checks:
         self.assertEqual(actual, want)
Ejemplo n.º 29
0
def main(configs):
    """Run a fixed-replay experiment with offline evaluation enabled.

    Args:
      configs: Object exposing gin_files, gin_bindings, replay_dir,
        agent_name, init_checkpoint_dir and base_dir attributes.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    base_run_experiment.load_gin_configs(configs.gin_files, configs.gin_bindings)
    replay_data_dir = os.path.join(configs.replay_dir, "replay_logs")
    agent_fn = functools.partial(
        create_agent,
        replay_data_dir=replay_data_dir,
        agent_name=configs.agent_name,
        init_checkpoint_dir=configs.init_checkpoint_dir,
    )
    env_fn = functools.partial(create_environment)
    runner = run_experiment.FixedReplayRunner(
        configs.base_dir, agent_fn, create_environment_fn=env_fn
    )

    # Offline-evaluation assets: held-out user dataset + reward-model ckpt.
    root = os.path.realpath(".")
    dataset_path = os.path.join(
        root, "data/processed/v5_dataset/test_dataset_users/")
    chkpt_path = os.path.join(
        root, "models/reward_pred_v0_model/release/80_input")
    runner.set_offline_evaluation(dataset_path, chkpt_path)

    runner.run_experiment()
Ejemplo n.º 30
0
def main(_, hello='Hello'):
    """Entry point: optionally load gin configs, then run the bubble agent.

    Args:
      _: Unused positional argument.
      hello: Unused; kept for interface compatibility.
    """
    # None tells bubble_agent.run to fall back to its default config.
    run_config = None

    if FLAGS.gin_files:
        from dopamine.discrete_domains import run_experiment
        print('! load gin-files:{}'.format(FLAGS.gin_files))
        # Bind the retro environment level before parsing the configs.
        FLAGS.gin_bindings.append(
            'retro_lib.create_retro_environment.level = {}'.format(
                FLAGS.level))
        print('! gin-bindings={}'.format(FLAGS.gin_bindings))
        run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
        run_config = ''  # do NOT load the default config in `bubble_agent.run`

    # run main.
    bubble_agent.run(agent=FLAGS.agent,
                     game=FLAGS.game,
                     level=FLAGS.level,
                     num_steps=FLAGS.steps,
                     root_dir=FLAGS.root_dir,
                     restore_ckpt=FLAGS.restore_checkpoint,
                     use_legacy_checkpoint=FLAGS.use_legacy_checkpoint,
                     config=run_config)