Example #1
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    env_descriptor = env.get_descriptor()

    mzconfig = agent_lib.muzeroconfig_from_flags(env_descriptor=env_descriptor)

    # Bind the flag-derived agent config now so the learner can construct
    # agents later without re-reading flags.
    create_agent_fn = functools.partial(
        agent_lib.create_agent,
        agent_config=agent_lib.agent_config_from_flags())

    if FLAGS.run_mode == 'actor':
        logging.info('Make actor, %s/%s', FLAGS.task, FLAGS.num_envs)
        actor.actor_loop(
            functools.partial(
                env.create_environment,
                stop_after_seeing_new_results=(
                    common_flags.STOP_AFTER_SEEING_NEW_RESULTS.value > 0)),
            mzconfig)
    elif FLAGS.run_mode == 'learner':
        learner.learner_loop(env_descriptor=env_descriptor,
                             create_agent_fn=create_agent_fn,
                             create_optimizer_fn=create_optimizer,
                             config=learner_flags.learner_config_from_flags(),
                             mzconfig=mzconfig)
    else:
        raise ValueError('Unsupported run mode {}'.format(FLAGS.run_mode))
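These `main` functions are launched through absl. A minimal sketch of the surrounding entry point, assuming absl-style flag definitions; only `run_mode` appears in the code above, and the enum values mirror the branches it checks:

from absl import app
from absl import flags

FLAGS = flags.FLAGS

# Illustrative flag definition; the real module defines many more flags
# (task, num_envs, temperature, ...).
flags.DEFINE_enum('run_mode', 'learner', ['actor', 'learner'],
                  'Whether this process runs the actor loop or the learner loop.')

if __name__ == '__main__':
    app.run(main)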
Example #2
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    def visit_softmax_temperature(num_moves, training_steps, is_training=True):
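        # Anneal the MCTS visit-softmax temperature over training. The step
        # thresholds below are scaled by 1024 / batch_size, i.e. they were
        # tuned at a reference batch size of 1024.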
        if not is_training:
            return 0.
        if training_steps < 500e3 * 1024 / FLAGS.batch_size:
            return 1. * FLAGS.temperature
        elif training_steps < 750e3 * 1024 / FLAGS.batch_size:
            return 0.5 * FLAGS.temperature
        else:
            return 0.25 * FLAGS.temperature

    env_descriptor = env.get_descriptor()

    known_bounds = None
    mzconfig = mzcore.MuZeroConfig(
        action_space_size=env_descriptor.action_space.n,
        max_moves=27000,
        discount=1.0 - FLAGS.one_minus_discount,
        dirichlet_alpha=FLAGS.dirichlet_alpha,
        root_exploration_fraction=FLAGS.root_exploration_fraction,
        num_simulations=FLAGS.num_simulations,
        initial_inference_batch_size=(
            learner_flags.INITIAL_INFERENCE_BATCH_SIZE.value),
        recurrent_inference_batch_size=(
            learner_flags.RECURRENT_INFERENCE_BATCH_SIZE.value),
        train_batch_size=learner_flags.BATCH_SIZE.value,
        td_steps=FLAGS.td_steps,
        num_unroll_steps=FLAGS.num_unroll_steps,
        pb_c_base=FLAGS.pb_c_base,
        pb_c_init=FLAGS.pb_c_init,
        known_bounds=known_bounds,
        visit_softmax_temperature_fn=visit_softmax_temperature,
        use_softmax_for_action_selection=(
            FLAGS.use_softmax_for_action_selection == 1))

    if FLAGS.run_mode == 'actor':
        actor.actor_loop(env.create_environment, mzconfig)
    elif FLAGS.run_mode == 'learner':
        learner.learner_loop(env_descriptor,
                             create_agent,
                             create_optimizer,
                             learner_flags.learner_config_from_flags(),
                             mzconfig,
                             pretraining=(FLAGS.pretraining == 1))
    else:
        raise ValueError('Unsupported run mode {}'.format(FLAGS.run_mode))
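The schedule above drops the temperature in two steps as training progresses. A self-contained sketch of the same arithmetic, with an assumed batch size of 1024 so the breakpoints land at 500k and 750k training steps:

def temperature_schedule(training_steps, batch_size=1024, base_temperature=1.0):
    # Mirrors visit_softmax_temperature above: breakpoints scale inversely
    # with batch size, so the anneal happens after a fixed amount of data.
    if training_steps < 500e3 * 1024 / batch_size:
        return 1.0 * base_temperature
    elif training_steps < 750e3 * 1024 / batch_size:
        return 0.5 * base_temperature
    return 0.25 * base_temperature

for steps in (0, 499_999, 500_000, 749_999, 750_000):
    print(steps, temperature_schedule(steps))
    # -> 1.0, 1.0, 0.5, 0.5, 0.25 respectively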
Example #3
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    def visit_softmax_temperature(num_moves, training_steps, is_training=True):
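        # Sample with a fixed exploration temperature early in the episode,
        # then play greedily (temperature 0) once num_moves reaches
        # play_max_after_moves; a negative flag value disables the cutoff.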
        if not is_training:
            return 0.
        if FLAGS.play_max_after_moves < 0:
            return FLAGS.temperature
        if num_moves < FLAGS.play_max_after_moves:
            return FLAGS.temperature
        else:
            return 0.

    env_descriptor = env.get_descriptor()

    # Known bounds for Q-values must cover both the reward range and the
    # value range.
    known_bounds = mzcore.KnownBounds(*map(
        sum, zip(env_descriptor.reward_range, env_descriptor.value_range)))
    mzconfig = mzcore.MuZeroConfig(
        action_space_size=env_descriptor.action_space.n,
        max_moves=env_descriptor.action_space.n,
        discount=1.0 - FLAGS.one_minus_discount,
        dirichlet_alpha=FLAGS.dirichlet_alpha,
        root_exploration_fraction=FLAGS.root_exploration_fraction,
        num_simulations=FLAGS.num_simulations,
        initial_inference_batch_size=(
            learner_flags.INITIAL_INFERENCE_BATCH_SIZE.value),
        recurrent_inference_batch_size=(
            learner_flags.RECURRENT_INFERENCE_BATCH_SIZE.value),
        train_batch_size=learner_flags.BATCH_SIZE.value,
        td_steps=FLAGS.td_steps,
        num_unroll_steps=FLAGS.num_unroll_steps,
        pb_c_base=FLAGS.pb_c_base,
        pb_c_init=FLAGS.pb_c_init,
        known_bounds=known_bounds,
        visit_softmax_temperature_fn=visit_softmax_temperature,
        use_softmax_for_action_selection=(
            FLAGS.use_softmax_for_action_selection == 1),
        max_num_action_expansion=FLAGS.max_num_action_expansion)

    if FLAGS.run_mode == 'actor':
        actor.actor_loop(env.create_environment, mzconfig)
    elif FLAGS.run_mode == 'learner':
        learner.learner_loop(env_descriptor, create_agent, create_optimizer,
                             learner_flags.learner_config_from_flags(),
                             mzconfig)
    else:
        raise ValueError('Unsupported run mode {}'.format(FLAGS.run_mode))
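Unlike Example #2, this variant passes concrete `known_bounds`, summing the reward and value ranges element-wise so the bounds cover any Q-value (a reward plus a discounted value). A worked sketch with illustrative ranges; the `KnownBounds(min, max)` constructor is assumed to follow the MuZero pseudocode:

import collections

KnownBounds = collections.namedtuple('KnownBounds', ['min', 'max'])

reward_range = (-1.0, 1.0)   # illustrative per-step reward bounds
value_range = (-10.0, 10.0)  # illustrative value-estimate bounds

# Element-wise sum of the two (min, max) pairs.
bounds = KnownBounds(*map(sum, zip(reward_range, value_range)))
print(bounds)  # KnownBounds(min=-11.0, max=11.0)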