Esempio n. 1
0
def main(argv):
    """Train a flat DQN agent on the Fig. 4 Scavenger task.

    Builds the task from ``configs.get_fig4_task_config()`` and wraps it with
    ``environment_wrappers.EnvironmentWithLogging``. A ``dqn_agent.Agent`` is
    then trained by ``experiment.run`` for ``FLAGS.num_episodes`` episodes,
    reporting every ``FLAGS.report_every`` episodes. When
    ``FLAGS.output_path`` is non-empty, the EMA returns produced by the run
    are written to that path.

    Args:
      argv: Command-line arguments; ignored.
    """
    del argv  # Unused.

    # Task environment, wrapped for logging.
    task_config = configs.get_fig4_task_config()
    task_env = environment_wrappers.EnvironmentWithLogging(
        scavenger.Scavenger(**task_config))

    # Flat agent: acts directly on the task's observation/action specs.
    obs_spec = task_env.observation_spec()
    action_spec = task_env.action_spec()
    net_kwargs = {"output_sizes": (64, 128), "activate_final": True}
    adam_kwargs = {"learning_rate": 3e-4}
    flat_agent = dqn_agent.Agent(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network_kwargs=net_kwargs,
        epsilon=0.1,
        additional_discount=0.9,
        batch_size=10,
        optimizer_name="AdamOptimizer",
        optimizer_kwargs=adam_kwargs)

    _, ema_returns = experiment.run(
        task_env,
        flat_agent,
        num_episodes=FLAGS.num_episodes,
        report_every=FLAGS.report_every)

    # Persisting results is optional; skip when no output path is given.
    if not FLAGS.output_path:
        return
    experiment.write_returns_to_file(FLAGS.output_path, ema_returns)
def main(argv):
    """Train a regressed-weight player that acts through a loaded keyboard.

    Loads a keyboard from ``FLAGS.keyboard_path`` via ``hub.Module`` and
    ``smart_module.SmartModuleImport``. It wraps the logged Fig. 4 Scavenger
    task with ``EnvironmentWithKeyboardDirect``, using an additional discount
    of 0.9 and ``call_and_return=False``. A ``regressed_agent.Agent`` is then
    trained with Adam (learning rate 3e-2), starting from small random weights
    with one entry per keyboard cumulant. Training runs through
    ``experiment.run`` with 100 evaluation repetitions. When
    ``FLAGS.output_path`` is non-empty, the EMA returns are written there.

    Args:
      argv: Command-line arguments; ignored.
    """
    del argv  # Unused.

    # Pre-trained keyboard module.
    keyboard = smart_module.SmartModuleImport(hub.Module(FLAGS.keyboard_path))

    # Underlying task environment, wrapped for logging.
    base_env = environment_wrappers.EnvironmentWithLogging(
        scavenger.Scavenger(**configs.get_fig4_task_config()))

    # Expose the keyboard directly to the player on top of the task.
    discount = 0.9
    env = environment_wrappers.EnvironmentWithKeyboardDirect(
        env=base_env,
        keyboard=keyboard,
        keyboard_ckpt_path=None,
        additional_discount=discount,
        call_and_return=False)

    # Random initial weights scaled by 0.1, sized to the keyboard's cumulants.
    initial_w = np.random.normal(size=keyboard.num_cumulants) * 0.1
    player = regressed_agent.Agent(
        batch_size=10,
        optimizer_name="AdamOptimizer",
        optimizer_kwargs={"learning_rate": 3e-2},
        init_w=initial_w,
    )

    _, ema_returns = experiment.run(
        env,
        player,
        num_episodes=FLAGS.num_episodes,
        report_every=FLAGS.report_every,
        num_eval_reps=100)

    # Persisting results is optional; skip when no output path is given.
    if not FLAGS.output_path:
        return
    experiment.write_returns_to_file(FLAGS.output_path, ema_returns)
def main(argv):
  """Evaluate a fixed-weight player acting through a loaded keyboard.

  Loads a keyboard from ``FLAGS.keyboard_path`` and wraps the logged Fig. 4
  Scavenger task with ``EnvironmentWithKeyboardDirect``. It builds a
  ``regressed_agent.Agent`` whose learning rate is 0.0, so its weights stay at
  ``init_w``. It then runs ``FLAGS.num_episodes`` episodes with
  ``experiment.run_episode`` and logs the mean return between two separator
  lines.

  Args:
    argv: Command-line arguments; ignored.
  """
  del argv  # Unused.

  # Load the keyboard.
  keyboard = smart_module.SmartModuleImport(hub.Module(FLAGS.keyboard_path))

  # Create the task environment.
  base_env_config = configs.get_fig4_task_config()
  base_env = scavenger.Scavenger(**base_env_config)
  base_env = environment_wrappers.EnvironmentWithLogging(base_env)

  # Wrap the task environment with the keyboard.
  additional_discount = 0.9
  env = environment_wrappers.EnvironmentWithKeyboardDirect(
      env=base_env,
      keyboard=keyboard,
      keyboard_ckpt_path=None,
      additional_discount=additional_discount,
      call_and_return=False)

  # Create the player agent.
  agent = regressed_agent.Agent(
      batch_size=10,
      optimizer_name="AdamOptimizer",
      # Disable training: a zero learning rate keeps the weights at init_w.
      optimizer_kwargs=dict(learning_rate=0.0,),
      # NOTE(review): a length-2 weight vector presumably assumes the loaded
      # keyboard has exactly two cumulants -- confirm against the keyboard.
      init_w=[1., -1.])

  returns = [experiment.run_episode(env, agent)
             for _ in range(FLAGS.num_episodes)]
  tf.logging.info("#" * 80)
  # Lazy %-style arguments; the rendered text matches the previous f-string.
  tf.logging.info("Avg. return over %s episodes is %s",
                  FLAGS.num_episodes, np.mean(returns))
  tf.logging.info("#" * 80)