Example #1
def main(_):
    key = jax.random.PRNGKey(FLAGS.seed)
    key_demonstrations, key_learner = jax.random.split(key, 2)

    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Get a demonstrations dataset with a next_action extra.
    transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                        FLAGS.num_demonstrations)
    double_transitions = rlds.transformations.batch(transitions,
                                                    size=2,
                                                    shift=1,
                                                    drop_remainder=True)
    transitions = double_transitions.map(_add_next_action_extras)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

    # Create the networks to optimize.
    networks = td3.make_networks(environment_spec)

    # Create the learner.
    learner = td3.TD3Learner(
        networks=networks,
        random_key=key_learner,
        discount=FLAGS.discount,
        iterator=demonstrations,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        use_sarsa_target=FLAGS.use_sarsa_target,
        bc_alpha=FLAGS.bc_alpha,
        num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.ndarray,
                          observation: jnp.ndarray) -> jnp.ndarray:
        del key
        return networks.policy_network.apply(params, observation)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
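
The pipeline above maps _add_next_action_extras over two-step batches, but the helper itself is not shown in this example. A minimal sketch of what it could look like, assuming the dataset yields acme.types.Transition elements whose fields gain a leading dimension of 2 after rlds.transformations.batch:

from acme import types

def _add_next_action_extras(double_transitions: types.Transition) -> types.Transition:
    # Keep the first transition of each consecutive pair, and store the action of
    # the second transition as the `next_action` extra (consumed when
    # use_sarsa_target is enabled).
    return types.Transition(
        observation=double_transitions.observation[0],
        action=double_transitions.action[0],
        reward=double_transitions.reward[0],
        discount=double_transitions.discount[0],
        next_observation=double_transitions.next_observation[0],
        extras={'next_action': double_transitions.action[1]})
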
Example #2
def make_demonstration_iterator(batch_size: int,
                                dataset_name: str,
                                seed: int = 0):
    dataset = tfds.get_tfds_dataset(dataset_name)
    return tfds.JaxInMemoryRandomSampleIterator(dataset,
                                                jax.random.PRNGKey(seed),
                                                batch_size)
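
A hypothetical usage of this helper; the dataset id and values below are assumptions, not taken from the snippet:

# Any transition dataset id accepted by tfds.get_tfds_dataset should work here.
iterator = make_demonstration_iterator(
    batch_size=256, dataset_name='d4rl_mujoco_halfcheetah/v0-medium', seed=42)
batch = next(iterator)  # one batch of 256 transitions as JAX arrays
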
Example #3
def main(_):
    key = jax.random.PRNGKey(FLAGS.seed)
    key_demonstrations, key_learner = jax.random.split(key, 2)

    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Get a demonstrations dataset.
    transitions_iterator = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                                 FLAGS.num_demonstrations)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions_iterator,
        key=key_demonstrations,
        batch_size=FLAGS.batch_size)

    # Create the networks to optimize.
    networks = cql.make_networks(environment_spec)

    # Create the learner.
    learner = cql.CQLLearner(
        batch_size=FLAGS.batch_size,
        networks=networks,
        random_key=key_learner,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        fixed_cql_coefficient=FLAGS.fixed_cql_coefficient,
        cql_lagrange_threshold=FLAGS.cql_lagrange_threshold,
        demonstrations=demonstrations,
        num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.ndarray,
                          observation: jnp.ndarray) -> jnp.ndarray:
        dist_params = networks.policy_network.apply(params, observation)
        return networks.sample_eval(dist_params, key)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
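
Examples #1 and #3 read their configuration from module-level absl flags that are not part of the snippets. A minimal sketch of how they might be declared; the flag names mirror the FLAGS.* references above, and every default value is an assumption:

from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('env_name', 'HalfCheetah-v2', 'Gym environment name.')
flags.DEFINE_string('dataset_name', 'd4rl_mujoco_halfcheetah/v0-medium',
                    'TFDS dataset to load demonstrations from.')
flags.DEFINE_integer('num_demonstrations', 10, 'Number of demonstration episodes.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('batch_size', 256, 'Learner batch size.')
flags.DEFINE_float('policy_learning_rate', 3e-4, 'Policy learning rate.')
flags.DEFINE_float('critic_learning_rate', 3e-4, 'Critic learning rate.')
flags.DEFINE_integer('evaluate_every', 100, 'Learner steps between evaluations.')
flags.DEFINE_integer('evaluation_episodes', 10, 'Episodes to run per evaluation.')
# Algorithm-specific flags (discount, use_sarsa_target, bc_alpha,
# fixed_cql_coefficient, cql_lagrange_threshold, ...) follow the same pattern.
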
Example #4
def main(_):
    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=_ENV_NAME.value)
    spec = specs.make_environment_spec(environment)

    key = jax.random.PRNGKey(_SEED.value)
    key, dataset_key, evaluator_key = jax.random.split(key, 3)

    # Load the dataset.
    dataset = tensorflow_datasets.load(_DATASET_NAME.value)['train']
    # Convert the episodes into timestep-batched transitions with a 10-step return horizon.
    dataset = mbop.episodes_to_timestep_batched_transitions(dataset,
                                                            return_horizon=10)
    dataset = tfds.JaxInMemoryRandomSampleIterator(
        dataset, key=dataset_key, batch_size=_BATCH_SIZE.value)

    # Compute normalization statistics and apply them to the dataset.
    mean_std = mbop.get_normalization_stats(dataset,
                                            _NUM_NORMALIZATION_BATCHES.value)
    apply_normalization = jax.jit(
        functools.partial(running_statistics.normalize, mean_std=mean_std))
    dataset = (apply_normalization(sample) for sample in dataset)

    # Create the networks.
    networks = mbop.make_networks(spec,
                                  hidden_layer_sizes=tuple(
                                      _HIDDEN_LAYER_SIZES.value))

    # Use the default losses.
    losses = mbop.MBOPLosses()

    def logger_fn(label: str, steps_key: str):
        return loggers.make_default_logger(label, steps_key=steps_key)

    def make_learner(name, logger_fn, counter, rng_key, dataset, network,
                     loss):
        return mbop.make_ensemble_regressor_learner(
            name,
            _NUM_NETWORKS.value,
            logger_fn,
            counter,
            rng_key,
            dataset,
            network,
            loss,
            optax.adam(_LEARNING_RATE.value),
            _NUM_SGD_STEPS_PER_STEP.value,
        )

    learner = mbop.MBOPLearner(
        networks, losses, dataset, key, logger_fn,
        functools.partial(make_learner, 'world_model'),
        functools.partial(make_learner, 'policy_prior'),
        functools.partial(make_learner, 'n_step_return'))

    planning_config = mbop.MPPIConfig()

    assert planning_config.n_trajectories % _NUM_NETWORKS.value == 0, (
        'Number of trajectories must be a multiple of the number of networks.')

    actor_core = mbop.make_ensemble_actor_core(networks,
                                               planning_config,
                                               spec,
                                               mean_std,
                                               use_round_robin=False)
    evaluator = mbop.make_actor(actor_core, evaluator_key, learner)

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Train the agent.
    while True:
        for _ in range(_EVALUATE_EVERY.value):
            learner.step()
        eval_loop.run(_EVALUATION_EPISODES.value)
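
The assertion above requires the planner's trajectory count to be a multiple of the ensemble size. A small sketch of rounding a desired trajectory count up to the nearest valid value; the numbers are assumptions, and MPPIConfig is assumed to accept n_trajectories as a keyword based on the attribute checked above:

from acme.agents.jax import mbop

num_networks = 10   # assumed ensemble size, mirroring _NUM_NETWORKS.value
desired = 32        # assumed number of planner trajectories

# Round up to the nearest multiple of the ensemble size so trajectories can be
# split evenly across ensemble members.
n_trajectories = -(-desired // num_networks) * num_networks  # -> 40

planning_config = mbop.MPPIConfig(n_trajectories=n_trajectories)
assert planning_config.n_trajectories % num_networks == 0
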
Example #5
def make_demonstrations(batch_size: int):
    # `transitions_iterator` and `seed` are captured from the enclosing scope.
    return acme_tfds.JaxInMemoryRandomSampleIterator(
        transitions_iterator, jax.random.PRNGKey(seed), batch_size)
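
This factory closes over transitions_iterator and seed defined in an enclosing scope. A minimal sketch of how they might be produced, following the pattern of Example #3; the dataset id and values are assumptions, and acme_tfds is assumed to be acme.datasets.tfds:

from acme.datasets import tfds as acme_tfds

seed = 0  # assumed
# Assumed dataset id and demonstration count; any dataset supported by
# get_tfds_dataset should work.
transitions_iterator = acme_tfds.get_tfds_dataset(
    'd4rl_mujoco_halfcheetah/v0-medium', 10)

demonstrations = make_demonstrations(batch_size=256)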