def main(unused_argv):
  """Runs the experiment and appends wall-clock timings to CSV files.

  Loads the gin configuration, derives an experiment name from the first
  gin file, runs the experiment, and records its runtime (and optionally
  the inference-test runtime) under <base_dir>/runtime/.

  Args:
    unused_argv: Arguments (unused).
  """
  # init logging
  tf.logging.set_verbosity(tf.logging.INFO)
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  # Derive the experiment name from the first gin file, e.g.
  # 'configs/cartpole.gin' -> 'cartpole'.
  ginfile = str(FLAGS.gin_files[0])
  experiment_name = ginfile[ginfile.rfind('/') + 1:ginfile.rfind('.gin')]
  log_dir = os.path.join(FLAGS.base_dir, experiment_name)
  runtime_file = os.path.join(FLAGS.base_dir, 'runtime', 'runtime.csv')
  inference_file = os.path.join(FLAGS.base_dir, 'runtime', 'inference.csv')
  runner = checkpoint_runner.create_runner(log_dir)

  start_time = time.time()
  runner.run_experiment()
  end_time = time.time()
  # Context manager guarantees the handle is closed even if the write fails
  # (the original open/close pair leaked the handle on exceptions).
  with open(runtime_file, 'a+') as f:
    f.write(experiment_name + ', ' + str(end_time - start_time) + '\n')

  if runner.inference_steps:
    print("--- STARTING DOPAMINE CARTPOLE INFERENCE EXPERIMENT ---\n")
    start_time = time.time()
    runner.run_inference_test()
    end_time = time.time()
    with open(inference_file, 'a+') as f:
      f.write(experiment_name + ', ' + str(end_time - start_time) + '\n')
    print("--- DOPAMINE CARTPOLE INFERENCE EXPERIMENT COMPLETED ---\n")
def main(unused_argv):
  """Loads gin configs and runs a fixed-replay experiment.

  Args:
    unused_argv: Arguments (unused).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  base_run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  # The fixed-replay agent reads transitions from <replay_dir>/replay_logs.
  replay_data_dir = os.path.join(FLAGS.replay_dir, 'replay_logs')
  agent_fn = functools.partial(create_agent, replay_data_dir=replay_data_dir)
  run_experiment.FixedReplayRunner(FLAGS.base_dir, agent_fn).run_experiment()
def __init__(self, atari_roms_source, atari_roms_path, gin_files, gin_bindings,
             random_seed, no_op, best_sampling):
  """Copies ROMs, loads gin config, and builds the underlying runner.

  Args:
    atari_roms_source: Source location of the Atari ROMs to copy.
    atari_roms_path: Destination directory for the copied ROMs.
    gin_files: List of gin configuration files to load.
    gin_bindings: List of gin parameter bindings to apply.
    random_seed: Seed for the numpy RandomState shared with the runner.
    no_op: Forwarded to ConqurRunner — semantics defined there.
    best_sampling: Forwarded to ConqurRunner — semantics defined there.
  """
  # ROMs and gin config must be in place before ConqurRunner is constructed.
  atari_lib.copy_roms(atari_roms_source, destination_dir=atari_roms_path)
  run_experiment.load_gin_configs(gin_files, gin_bindings)
  self.random_state = np.random.RandomState(random_seed)
  self.runner = ConqurRunner(self.random_state, no_op, best_sampling)
  # Mirror the runner's action-space size on this wrapper.
  self.num_actions = self.runner.num_actions
  super(Atari, self).__init__()
def testLoadGinConfigs(self, mock_parse_config_files_and_bindings):
  """Verifies load_gin_configs forwards files and bindings to gin once."""
  gin_files = ['file1', 'file2', 'file3']
  gin_bindings = ['binding1', 'binding2']
  run_experiment.load_gin_configs(gin_files, gin_bindings)
  # gin must be invoked exactly once: files positionally, bindings as a
  # keyword argument, and unknown configurables not skipped.
  self.assertEqual(1, mock_parse_config_files_and_bindings.call_count)
  args, kwargs = mock_parse_config_files_and_bindings.call_args
  self.assertEqual(gin_files, args[0])
  self.assertEqual(gin_bindings, kwargs['bindings'])
  self.assertFalse(kwargs['skip_unknown'])
def run(unused_argv):
  """Configures gin, then trains while logging replay-buffer data.

  Args:
    unused_argv: Arguments (unused).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  run_experiment.load_gin_configs([FLAGS.gin_file], [])
  # Create the replay log dir.
  replay_log_dir = os.path.join(FLAGS.base_dir, 'replay_logs')
  tf.logging.info('Saving replay buffer data to {}'.format(replay_log_dir))
  create_agent_fn = functools.partial(
      create_agent, replay_log_dir=replay_log_dir)
  LoggedRunner(FLAGS.base_dir, create_agent_fn).run_experiment()
def main(unused_argv):
  """Main method.

  Args:
    unused_argv: Arguments (unused).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  # Runner construction and execution chained; no other use of the runner.
  run_experiment.create_runner(FLAGS.base_dir).run_experiment()
def main(unused_argv):
  """Builds a bubble runner for the configured level and runs it.

  Args:
    unused_argv: Arguments (unused).
  """
  # NOTE(review): imported inside the function in the original — presumably
  # to avoid an import cycle at module load time; kept local for that reason.
  from . import bubble_runner
  tf.logging.set_verbosity(tf.logging.INFO)
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  runner = bubble_runner.create_runner(FLAGS.base_dir, level=FLAGS.level)
  runner.run_experiment()
def main(unused_argv):
  """This main function acts as a wrapper around a gin-configurable experiment.

  Args:
    unused_argv: Arguments (unused).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  # schedule='save_best' — semantics defined by create_runner_checkpoint.
  runner = run_experiment_from_checkpoint.create_runner_checkpoint(
      FLAGS.base_dir, run_experiment.create_agent, schedule='save_best')
  runner.run_experiment()
def main(unused_argv):
  """Main method.

  Args:
    unused_argv: Arguments (unused).
  """
  logging.set_verbosity(logging.INFO)
  # Runs in TF1 graph mode; v2 behavior is disabled before anything else.
  tf.compat.v1.disable_v2_behavior()
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  run_experiment.create_runner(FLAGS.base_dir).run_experiment()
def main(unused_argv):
  """Runs an experiment in a timestamp-suffixed base directory.

  Args:
    unused_argv: Arguments (unused).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  # NOTE(review): stray debug print left in by the author — kept verbatim.
  print("levin: test ok")
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  # Suffix base_dir with a unix timestamp so repeated runs don't collide.
  FLAGS.base_dir = FLAGS.base_dir + '-' + str(int(time.time()))
  run_experiment.create_runner(FLAGS.base_dir).run_experiment()
def main(unused_argv):
  """Main method.

  Args:
    unused_argv: Arguments (unused).
  """
  logging.set_verbosity(logging.INFO)
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  runner = run_experiment.TrainRunner(FLAGS.base_dir, create_metric_agent)
  runner.run_experiment()
def launch_experiment(create_runner_fn, create_agent_fn):
  """Launches the experiment.

  Args:
    create_runner_fn: A function that takes as args a base directory and a
      function for creating an agent and returns a `Runner` like object.
    create_agent_fn: A function that takes as args a Tensorflow session and a
      Gym Atari 2600 environment, and returns an agent.
  """
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  runner = create_runner_fn(
      FLAGS.base_dir, create_agent_fn, schedule=FLAGS.schedule)
  runner.run_experiment()
def main(unused_argv):
  """Visualizes a bisimulation-metric agent restored from a checkpoint.

  Args:
    unused_argv: Arguments (unused).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  gin_file = 'bisimulation_aaai2020/dopamine/configs/rainbow.gin'
  dopamine_run_experiment.load_gin_configs([gin_file], FLAGS.gin_bindings)
  # Force metric-only evaluation via an extra gin binding (appended after
  # load so any later consumers of FLAGS.gin_bindings also see it).
  FLAGS.gin_bindings.append(
      'BisimulationRainbowAgent.evaluate_metric_only=True')
  runner = run_experiment.create_runner(FLAGS.base_dir,
                                        FLAGS.metric_checkpoint)
  runner.visualize(num_global_steps=FLAGS.num_global_steps)
def test_complete_experiment():
  """Smoke test: runs small CartPole / ParkQOpt experiments end-to-end.

  Fails the test if any exception is raised while configuring or running
  the Dopamine experiment.
  """
  try:
    # Keep test output quiet.
    tf.logging.set_verbosity(tf.logging.ERROR)
    # configure experiment
    run_experiment.load_gin_configs(PARAMS, [])
    # create the agent and run experiment
    checkpoint_runner.create_runner(BASE_DIR).run_experiment()
  except Exception:
    pytest.fail('Running experiments in Dopamine failed!')
def main(unused_argv):
  """This main function acts as a wrapper around a gin-configurable experiment.

  Args:
    unused_argv: Arguments (unused).
  """
  del unused_argv
  # Point LoadFromRunner at the original experiment directory via gin.
  gin_bindings = FLAGS.gin_bindings + [
      'LoadFromRunner.original_base_dir="{}"'.format(FLAGS.original_base_dir)
  ]
  tf.logging.set_verbosity(tf.logging.INFO)
  run_experiment.load_gin_configs(FLAGS.gin_files, gin_bindings)
  runner = run_experiment_from_checkpoint.create_runner_checkpoint(
      FLAGS.base_dir, run_experiment.create_agent, schedule='load_from_best')
  runner.run_experiment()
def testDefaultGinDqn(self):
  """Test DQNAgent configuration using the default gin config."""
  tf.logging.info('####### Training the DQN agent #####')
  tf.logging.info('####### DQN base_dir: {}'.format(self._base_dir))
  bindings = [
      'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
      "create_agent.agent_name = 'dqn'",
  ]
  run_experiment.load_gin_configs(['dopamine/agents/dqn/configs/dqn.gin'],
                                  bindings)
  runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                 atari_lib.create_atari_environment)
  # Default DQN config: RMSProp with learning rate 0.00025.
  self.assertIsInstance(runner._agent.optimizer, tf.train.RMSPropOptimizer)
  self.assertNear(0.00025, runner._agent.optimizer._learning_rate, 0.0001)
  shutil.rmtree(self._base_dir)
def main(unused_argv):
  """Main method.

  Args:
    unused_argv: Arguments (unused).
  """
  logging.set_verbosity(logging.INFO)
  # Tandem training runs in TF1 graph mode.
  tf.compat.v1.disable_v2_behavior()
  base_run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  runner = run_experiment.TandemRunner(
      FLAGS.base_dir, run_experiment.create_tandem_agents_and_checkpoints)
  runner.run_experiment()
def testDefaultGinRainbow(self):
  """Test RainbowAgent default configuration using default gin."""
  tf.logging.info('####### Training the RAINBOW agent #####')
  tf.logging.info('####### RAINBOW base_dir: {}'.format(self._base_dir))
  bindings = [
      'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
      "create_agent.agent_name = 'rainbow'",
  ]
  run_experiment.load_gin_configs(
      ['dopamine/agents/rainbow/configs/rainbow.gin'], bindings)
  runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                 atari_lib.create_atari_environment)
  # Default Rainbow config: Adam with learning rate 6.25e-5.
  self.assertIsInstance(runner._agent.optimizer, tf.train.AdamOptimizer)
  self.assertNear(0.0000625, runner._agent.optimizer._lr, 0.0001)
  shutil.rmtree(self._base_dir)
def testDefaultDQNConfig(self):
  """Verify the default DQN configuration."""
  run_experiment.load_gin_configs(
      ['dopamine/agents/dqn/configs/dqn.gin'], [])
  agent = run_experiment.create_agent(
      self.test_session(),
      atari_lib.create_atari_environment(game_name='Pong'))
  # Expected values mirror the defaults in dqn.gin.
  expected = {
      'gamma': 0.99,
      'update_horizon': 1,
      'min_replay_history': 20000,
      'update_period': 4,
      'target_update_period': 8000,
      'epsilon_train': 0.01,
      'epsilon_eval': 0.001,
      'epsilon_decay_period': 250000,
  }
  for attr, value in expected.items():
    self.assertEqual(getattr(agent, attr), value)
  self.assertEqual(agent._replay.memory._replay_capacity, 1000000)
  self.assertEqual(agent._replay.memory._batch_size, 32)
def testDefaultGinImplicitQuantileIcml(self):
  """Test default ImplicitQuantile configuration using ICML gin."""
  tf.logging.info('###### Training the Implicit Quantile agent #####')
  # Fresh, timestamped base dir per run.
  self._base_dir = os.path.join(
      '/tmp/dopamine_tests',
      datetime.datetime.utcnow().strftime('run_%Y_%m_%d_%H_%M_%S'))
  tf.logging.info('###### IQN base dir: {}'.format(self._base_dir))
  run_experiment.load_gin_configs(
      ['dopamine/agents/'
       'implicit_quantile/configs/implicit_quantile_icml.gin'],
      ['Runner.num_iterations=0'])
  runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                 atari_lib.create_atari_environment)
  # The ICML config keeps the full 1M-transition replay buffer.
  self.assertEqual(1000000, runner._agent._replay.memory._replay_capacity)
  shutil.rmtree(self._base_dir)
def testOverrideGinDqn(self):
  """Test DQNAgent configuration overridden with AdamOptimizer."""
  tf.logging.info('####### Training the DQN agent #####')
  tf.logging.info('####### DQN base_dir: {}'.format(self._base_dir))
  bindings = [
      'DQNAgent.optimizer = @tf.train.AdamOptimizer()',
      'tf.train.AdamOptimizer.learning_rate = 100',
      'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
      "create_agent.agent_name = 'dqn'",
  ]
  run_experiment.load_gin_configs(['dopamine/agents/dqn/configs/dqn.gin'],
                                  bindings)
  runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                 atari_lib.create_atari_environment)
  # The bindings above swap the default RMSProp for Adam with lr = 100.
  self.assertIsInstance(runner._agent.optimizer, tf.train.AdamOptimizer)
  self.assertEqual(100, runner._agent.optimizer._lr)
  shutil.rmtree(self._base_dir)
def testOverrideRunnerParams(self):
  """Test DQNAgent configuration using the default gin config."""
  tf.logging.info('####### Training the DQN agent #####')
  tf.logging.info('####### DQN base_dir: {}'.format(self._base_dir))
  bindings = [
      'TrainRunner.base_dir = "{}"'.format(self._base_dir),
      'Runner.log_every_n = 1729',
      'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
      "create_agent.agent_name = 'dqn'",
  ]
  run_experiment.load_gin_configs(['dopamine/agents/dqn/configs/dqn.gin'],
                                  bindings)
  # base_dir is supplied through the gin binding, so it is not passed here.
  runner = run_experiment.TrainRunner(
      create_agent_fn=run_experiment.create_agent,
      create_environment_fn=atari_lib.create_atari_environment)
  self.assertEqual(runner._base_dir, self._base_dir)
  self.assertEqual(runner._log_every_n, 1729)
  shutil.rmtree(self._base_dir)
def testOverrideGinRainbow(self):
  """Test RainbowAgent configuration overridden with RMSPropOptimizer."""
  tf.logging.info('####### Training the RAINBOW agent #####')
  tf.logging.info('####### RAINBOW base_dir: {}'.format(self._base_dir))
  bindings = [
      'RainbowAgent.optimizer = @tf.train.RMSPropOptimizer()',
      'tf.train.RMSPropOptimizer.learning_rate = 100',
      'WrappedReplayBuffer.replay_capacity = 100',  # To prevent OOM.
      "create_agent.agent_name = 'rainbow'",
  ]
  run_experiment.load_gin_configs(
      ['dopamine/agents/rainbow/configs/rainbow.gin'], bindings)
  runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                 atari_lib.create_atari_environment)
  # The bindings above swap the default Adam for RMSProp with lr = 100.
  self.assertIsInstance(runner._agent.optimizer, tf.train.RMSPropOptimizer)
  self.assertEqual(100, runner._agent.optimizer._learning_rate)
  shutil.rmtree(self._base_dir)
def main(unused_argv):
  """Main method.

  Args:
    unused_argv: Arguments (unused).
  """
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.disable_v2_behavior()
  run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  # Set the Jax agent seed using the run number
  create_agent_fn = functools.partial(create_agent, seed=FLAGS.run_number)
  if FLAGS.max_episode_eval:
    logging.info('Using MaxEpisodeEvalRunner for evaluation.')
    runner = eval_run_experiment.MaxEpisodeEvalRunner(
        FLAGS.base_dir, create_agent_fn)
  else:
    runner = run_experiment.Runner(FLAGS.base_dir, create_agent_fn)
  runner.run_experiment()
def testOverrideGinImplicitQuantile(self):
  """Test ImplicitQuantile configuration overriding using IQN gin."""
  logging.info('###### Training the Implicit Quantile agent #####')
  # Fresh, timestamped base dir per run.
  self._base_dir = os.path.join(
      '/tmp/dopamine_tests',
      datetime.datetime.utcnow().strftime('run_%Y_%m_%d_%H_%M_%S'))
  logging.info('###### IQN base dir: %s', self._base_dir)
  run_experiment.load_gin_configs(
      ['dopamine/agents/implicit_quantile/configs/'
       'implicit_quantile.gin'],
      ['Runner.num_iterations=0',
       'WrappedPrioritizedReplayBuffer.replay_capacity = 1000'])
  runner = run_experiment.Runner(self._base_dir, run_experiment.create_agent,
                                 atari_lib.create_atari_environment)
  # The overridden capacity (1000) must replace the config default.
  self.assertEqual(1000, runner._agent._replay.memory._replay_capacity)
  shutil.rmtree(self._base_dir)
def main(unused_argv):
  """Trains a fixed-replay agent on preference-model rewards, then packs it.

  The game name is recovered from the experiment directory layout
  (second-to-last path component of exp_dir) and bound into gin before
  training starts.

  Args:
    unused_argv: Arguments (unused).

  Raises:
    NotImplementedError: If --use_preference_rewards is not set; only the
      preference-reward path is implemented.
  """
  # Only the game component is used; the trailing split name is discarded.
  exp_parent_dir, _ = osp.split(FLAGS.exp_dir)
  _, game = osp.split(exp_parent_dir)
  gin.bind_parameter('atari_lib.create_atari_environment.game_name', game)
  if FLAGS.use_preference_rewards:
    training_log_path = create_logs_for_training(FLAGS)
    agent_name = "_".join([
        FLAGS.agent_name, FLAGS.preference_model_type, FLAGS.reward_model_type
    ])
    # Train from the freshly generated preference-reward logs.
    FLAGS.replay_dir = training_log_path
  else:
    raise NotImplementedError
  tf.logging.set_verbosity(tf.logging.INFO)
  base_run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
  replay_data_dir = os.path.join(FLAGS.replay_dir, 'replay_logs')
  create_agent_fn = functools.partial(
      create_agent, replay_data_dir=replay_data_dir)
  runner = run_experiment.FixedReplayRunner(
      osp.join(FLAGS.exp_dir, agent_name), create_agent_fn)
  runner.run_experiment()
  pack_agents(FLAGS)
def main(unused_argv):
  """Main method.

  Args:
    unused_argv: Arguments (unused).
  """
  logging.set_verbosity(logging.INFO)
  # XManager metadata is optional; fall back to None when a flag is absent.
  xm_xid = FLAGS.xm_xid if 'xm_xid' in FLAGS else None
  xm_wid = FLAGS.xm_wid if 'xm_wid' in FLAGS else None
  xm_parameters = FLAGS.xm_parameters if 'xm_parameters' in FLAGS else None
  base_dir, gin_files, gin_bindings = base_train.run_xm_preprocessing(
      xm_xid, xm_wid, xm_parameters, FLAGS.base_dir,
      FLAGS.custom_base_dir_from_hparams, FLAGS.gin_files, FLAGS.gin_bindings)
  create_agent = functools.partial(
      create_offline_agent, agent_name=FLAGS.agent_name)
  base_run_experiment.load_gin_configs(gin_files, gin_bindings)
  runner = run_experiment.FixedReplayRunner(base_dir, create_agent)
  runner.run_experiment()
def testDefaultRainbowConfig(self):
  """Verify the default Rainbow configuration."""
  run_experiment.load_gin_configs(
      ['dopamine/agents/rainbow/configs/rainbow.gin'], [])
  agent = run_experiment.create_agent(
      self.test_session(),
      atari_lib.create_atari_environment(game_name='Pong'))
  self.assertEqual(agent._num_atoms, 51)
  # The categorical support spans [-10, 10] across the 51 atoms.
  support = self.evaluate(agent._support)
  self.assertEqual(min(support), -10.)
  self.assertEqual(max(support), 10.)
  self.assertEqual(len(support), 51)
  # Expected values mirror the defaults in rainbow.gin.
  expected = {
      'gamma': 0.99,
      'update_horizon': 3,
      'min_replay_history': 20000,
      'update_period': 4,
      'target_update_period': 8000,
      'epsilon_train': 0.01,
      'epsilon_eval': 0.001,
      'epsilon_decay_period': 250000,
  }
  for attr, value in expected.items():
    self.assertEqual(getattr(agent, attr), value)
  self.assertEqual(agent._replay.memory._replay_capacity, 1000000)
  self.assertEqual(agent._replay.memory._batch_size, 32)
def main(configs):
  """Runs fixed-replay training followed by offline evaluation.

  Args:
    configs: Object exposing experiment settings (gin_files, gin_bindings,
      replay_dir, agent_name, init_checkpoint_dir, base_dir).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  base_run_experiment.load_gin_configs(configs.gin_files, configs.gin_bindings)
  replay_data_dir = os.path.join(configs.replay_dir, "replay_logs")
  create_agent_fn = functools.partial(
      create_agent,
      replay_data_dir=replay_data_dir,
      agent_name=configs.agent_name,
      init_checkpoint_dir=configs.init_checkpoint_dir,
  )
  create_environment_fn = functools.partial(create_environment)
  runner = run_experiment.FixedReplayRunner(
      configs.base_dir, create_agent_fn,
      create_environment_fn=create_environment_fn)
  # Offline-evaluation assets are resolved relative to the current
  # working directory.
  cwd = os.path.realpath(".")
  dataset_path = os.path.join(
      cwd, "data/processed/v5_dataset/test_dataset_users/")
  chkpt_path = os.path.join(
      cwd, "models/reward_pred_v0_model/release/80_input")
  runner.set_offline_evaluation(dataset_path, chkpt_path)
  runner.run_experiment()
def main(_, hello='Hello'):
  """Entry point: optionally loads gin config, then runs the bubble agent."""
  # flag to use default config in bubble_agent.run.
  run_config = None
  # load gin configuration.
  if FLAGS.gin_files:
    from dopamine.discrete_domains import run_experiment
    print('! load gin-files:{}'.format(FLAGS.gin_files))
    FLAGS.gin_bindings.append(
        'retro_lib.create_retro_environment.level = {}'.format(FLAGS.level))
    print('! gin-bindings={}'.format(FLAGS.gin_bindings))
    run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
    run_config = ''  # do NOT load the default config in `bubble_agent.run`
  # run main.
  bubble_agent.run(
      agent=FLAGS.agent,
      game=FLAGS.game,
      level=FLAGS.level,
      num_steps=FLAGS.steps,
      root_dir=FLAGS.root_dir,
      restore_ckpt=FLAGS.restore_checkpoint,
      use_legacy_checkpoint=FLAGS.use_legacy_checkpoint,
      config=run_config)