Пример #1
0
def main(_):
    """Train a `td3.TD3Learner` offline on demonstrations and evaluate it.

    Builds the environment `FLAGS.env_name`, loads `FLAGS.dataset_name`
    (limited via `FLAGS.num_demonstrations`), augments each transition with
    next-action extras, and trains the learner on random in-memory batches.
    After every `FLAGS.evaluate_every` learner steps it runs
    `FLAGS.evaluation_episodes` evaluation episodes. The loop never exits on
    its own.

    Args:
      _: Positional argument from the launcher (presumably argv from
        `absl.app.run`); unused.
    """
    key = jax.random.PRNGKey(FLAGS.seed)
    # Give every consumer its own independent key. The parent `key` must not
    # be reused once it has been split; otherwise the evaluator's random
    # stream would be correlated with the sampler's and the learner's.
    key_demonstrations, key_learner, key_evaluator = jax.random.split(key, 3)

    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Get a demonstrations dataset with next_actions extra.
    transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                        FLAGS.num_demonstrations)
    # Sliding window of 2 with stride 1: each element pairs a transition with
    # its successor, so `_add_next_action_extras` can attach the next action
    # (presumably read from the second element of each pair). Incomplete
    # trailing windows are dropped.
    double_transitions = rlds.transformations.batch(transitions,
                                                    size=2,
                                                    shift=1,
                                                    drop_remainder=True)
    transitions = double_transitions.map(_add_next_action_extras)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

    # Create the networks to optimize.
    networks = td3.make_networks(environment_spec)

    # Create the learner.
    learner = td3.TD3Learner(
        networks=networks,
        random_key=key_learner,
        discount=FLAGS.discount,
        iterator=demonstrations,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        use_sarsa_target=FLAGS.use_sarsa_target,
        bc_alpha=FLAGS.bc_alpha,
        num_sgd_steps_per_step=1)

    # NOTE(review): `jnp.DeviceArray` is deprecated/removed in recent JAX
    # releases (`jax.Array` replaces it); confirm against the pinned version.
    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        """Return the policy network output for `observation`; `key` unused."""
        # No sampling happens here, so the per-step key is not needed.
        del key
        return networks.policy_network.apply(params, observation)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    # Fetches the learner's 'policy' parameters onto the CPU for the actor.
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key_evaluator,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Alternate training and evaluation indefinitely.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Пример #2
0
def make_demonstration_iterator(batch_size: int,
                                dataset_name: str,
                                seed: int = 0):
    """Return a random-sample iterator over the named demonstrations dataset.

    Args:
      batch_size: Batch size forwarded to the sampling iterator.
      dataset_name: Dataset name passed to `tfds.get_tfds_dataset`.
      seed: Integer seed used to build the iterator's JAX PRNG key.

    Returns:
      A `tfds.JaxInMemoryRandomSampleIterator` over the loaded dataset.
    """
    episodes = tfds.get_tfds_dataset(dataset_name)
    sampler_key = jax.random.PRNGKey(seed)
    return tfds.JaxInMemoryRandomSampleIterator(
        episodes,
        key=sampler_key,
        batch_size=batch_size,
    )
Пример #3
0
def main(_):
    """Train a `cql.CQLLearner` offline on demonstrations and evaluate it.

    Builds the environment `FLAGS.env_name`, loads `FLAGS.dataset_name`
    (limited via `FLAGS.num_demonstrations`), and trains the learner on
    random in-memory batches. After every `FLAGS.evaluate_every` learner
    steps it runs `FLAGS.evaluation_episodes` evaluation episodes. The loop
    never exits on its own.

    Args:
      _: Positional argument from the launcher (presumably argv from
        `absl.app.run`); unused.
    """
    key = jax.random.PRNGKey(FLAGS.seed)
    # Give every consumer its own independent key; the parent `key` is not
    # reused after splitting. This matters here because the evaluator samples
    # actions with the key it is handed (see `networks.sample_eval` below).
    key_demonstrations, key_learner, key_evaluator = jax.random.split(key, 3)

    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Get a demonstrations dataset and sample random batches from it in memory.
    transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                        FLAGS.num_demonstrations)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions,
        key=key_demonstrations,
        batch_size=FLAGS.batch_size)

    # Create the networks to optimize.
    networks = cql.make_networks(environment_spec)

    # Create the learner.
    learner = cql.CQLLearner(
        batch_size=FLAGS.batch_size,
        networks=networks,
        random_key=key_learner,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        fixed_cql_coefficient=FLAGS.fixed_cql_coefficient,
        cql_lagrange_threshold=FLAGS.cql_lagrange_threshold,
        demonstrations=demonstrations,
        num_sgd_steps_per_step=1)

    # NOTE(review): `jnp.DeviceArray` is deprecated/removed in recent JAX
    # releases (`jax.Array` replaces it); confirm against the pinned version.
    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        """Select an evaluation action for `observation` using `key`."""
        # The policy output is presumably distribution parameters;
        # `sample_eval` turns them into an action using the per-step key.
        dist_params = networks.policy_network.apply(params, observation)
        return networks.sample_eval(dist_params, key)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    # Fetches the learner's 'policy' parameters onto the CPU for the actor.
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key_evaluator,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Alternate training and evaluation indefinitely.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)