예제 #1
0
def make_demonstrations(env: dm_env.Environment,
                        batch_size: int) -> tf.data.Dataset:
    """Prepare a batched, prefetched dataset of one-step demonstrations.

    Args:
      env: The bsuite environment to build demonstrations for. If it exposes
        an underlying environment through a `raw_env` attribute, that
        underlying environment is used instead. The `main` functions in this
        file unwrap the result of `bsuite.load_and_record_to_csv` the same
        way before building demonstrations.
      batch_size: Number of transitions per batch. A final partial batch is
        dropped.

    Returns:
      A `tf.data.Dataset` whose elements are batches of transitions produced
      by `_n_step_transition_from_episode` with `n_step=1`.
    """
    # Use the wrapped environment when present, else `env` itself.
    demo_env = getattr(env, 'raw_env', env)
    episodes = bsuite_demonstrations.make_dataset(demo_env, stochastic=False)

    # Map each dataset element (presumably an episode, per the helper's name)
    # to one-step transitions.
    to_transition = functools.partial(_n_step_transition_from_episode,
                                      n_step=1,
                                      additional_discount=1.)
    dataset = episodes.map(to_transition)

    # drop_remainder=True keeps the batch dimension static across batches.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)
예제 #2
0
def main(_):
  """Build a bsuite environment and run a DQfD agent on it."""
  # Load the bsuite environment via `load_and_record_to_csv`.
  recorded_env = bsuite.load_and_record_to_csv(
      bsuite_id=FLAGS.bsuite_id,
      results_dir=FLAGS.results_dir,
      overwrite=FLAGS.overwrite,
  )
  # The agent interacts with a single-precision view of that environment.
  env = wrappers.SinglePrecisionWrapper(recorded_env)
  env_spec = specs.make_environment_spec(env)

  # Network first, then demonstrations built from the loaded environment
  # (not the single-precision wrapper).
  q_network = make_network(env_spec.actions)
  demonstrations = bsuite_demonstrations.make_dataset(recorded_env)

  agent = dqfd.DQfD(
      environment_spec=env_spec,
      network=q_network,
      demonstration_dataset=demonstrations,
      demonstration_ratio=FLAGS.demonstration_ratio,
      samples_per_insert=FLAGS.samples_per_insert,
      learning_rate=FLAGS.learning_rate)

  # Run for as many episodes as the bsuite experiment specifies.
  env_loop = acme.EnvironmentLoop(env, agent)
  episodes_to_run = env.bsuite_num_episodes  # pytype: disable=attribute-error
  env_loop.run(num_episodes=episodes_to_run)
예제 #3
0
def main(_):
    """Train a policy with `learning.BCLearner` on bsuite demonstrations.

    Alternates `FLAGS.evaluate_every` learner steps with
    `FLAGS.evaluation_episodes` evaluation episodes, indefinitely.
    """
    # Create an environment and grab the spec.
    raw_environment = bsuite.load_and_record_to_csv(
        bsuite_id=FLAGS.bsuite_id,
        results_dir=FLAGS.results_dir,
        overwrite=FLAGS.overwrite,
    )
    environment = single_precision.SinglePrecisionWrapper(raw_environment)
    environment_spec = specs.make_environment_spec(environment)

    # Build the demonstration dataset from the underlying environment when
    # the loaded environment exposes one as `raw_env`.
    raw_environment = getattr(raw_environment, 'raw_env', raw_environment)

    batch_dataset = bsuite_demonstrations.make_dataset(raw_environment,
                                                       stochastic=False)
    # Convert dataset elements into one-step transitions.
    transition = functools.partial(_n_step_transition_from_episode,
                                   n_step=1,
                                   additional_discount=1.)
    dataset = batch_dataset.map(transition)

    # Batch (dropping any partial final batch) and prefetch.
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # Create the network to optimize.
    policy_network = make_policy_network(environment_spec.actions)

    # Evaluation policy: epsilon-greedy over the network outputs with
    # epsilon = FLAGS.epsilon. It is purely greedy only if FLAGS.epsilon is 0.
    evaluator_network = snt.Sequential([
        policy_network,
        lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
    ])

    # Ensure the variables exist before proceeding (may not be required).
    tf2_utils.create_variables(policy_network, [environment_spec.observations])

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # The actor that selects actions during evaluation.
    evaluator = actors.FeedForwardActor(evaluator_network)

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     counter=counter,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=1.))

    # The learner updates the parameters (and initializes them).
    learner = learning.BCLearner(network=policy_network,
                                 learning_rate=FLAGS.learning_rate,
                                 dataset=dataset,
                                 counter=learner_counter)

    # Alternate training and evaluation forever.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
예제 #4
0
파일: run_bc_jax.py 프로젝트: pchtsp/acme
def main(_):
    """Train a policy with the JAX `learning.BCLearner` on bsuite demos.

    Alternates `FLAGS.evaluate_every` learner steps with
    `FLAGS.evaluation_episodes` evaluation episodes, indefinitely.
    """
    # Create an environment and grab the spec.
    raw_environment = bsuite.load_and_record_to_csv(
        bsuite_id=FLAGS.bsuite_id,
        results_dir=FLAGS.results_dir,
        overwrite=FLAGS.overwrite,
    )
    environment = single_precision.SinglePrecisionWrapper(raw_environment)
    environment_spec = specs.make_environment_spec(environment)

    # Build the demonstration dataset from the underlying environment when
    # the loaded environment exposes one as `raw_env`.
    raw_environment = getattr(raw_environment, 'raw_env', raw_environment)

    batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)
    # Convert dataset elements into one-step transitions.
    transition = functools.partial(_n_step_transition_from_episode,
                                   n_step=1,
                                   additional_discount=1.)
    dataset = batch_dataset.map(transition)

    # Batch (dropping any partial final batch), prefetch, and expose the
    # batches as NumPy values for the JAX learner.
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    dataset = tfds.as_numpy(dataset)

    # Create the network to optimize. `without_apply_rng` makes its `apply`
    # take no RNG key.
    policy_network = make_policy_network(environment_spec.actions)
    policy_network = hk.without_apply_rng(hk.transform(policy_network))

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        """Sample an epsilon-greedy action with epsilon = FLAGS.epsilon.

        The policy is purely greedy only when FLAGS.epsilon is 0.
        """
        action_values = policy_network.apply(params, observation)
        return rlax.epsilon_greedy(FLAGS.epsilon).sample(key, action_values)

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # Give the learner and the actor independent key streams derived from
    # FLAGS.seed. Seeding two PRNGSequences with the same seed would hand
    # both components identical keys (JAX key reuse).
    seed_sequence = hk.PRNGSequence(FLAGS.seed)
    learner_rng = hk.PRNGSequence(next(seed_sequence))
    actor_rng = hk.PRNGSequence(next(seed_sequence))

    # The learner updates the parameters (and initializes them).
    learner = learning.BCLearner(network=policy_network,
                                 optimizer=optax.adam(FLAGS.learning_rate),
                                 obs_spec=environment.observation_spec(),
                                 dataset=dataset,
                                 counter=learner_counter,
                                 rng=learner_rng)

    # Create the actor which defines how we take actions. The variable client
    # presumably supplies it with the learner's current parameters.
    variable_client = variable_utils.VariableClient(learner, '')
    evaluator = actors.FeedForwardActor(evaluator_network,
                                        variable_client=variable_client,
                                        rng=actor_rng)

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     counter=counter,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=1.))

    # Alternate training and evaluation forever.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)