Example 1
    def test_discrete_actions(self, loss_name):
        with chex.fake_pmap_and_jit():

            num_sgd_steps_per_step = 1
            num_steps = 5

            # Create a fake environment to test with.
            environment = fakes.DiscreteEnvironment(num_actions=10,
                                                    num_observations=100,
                                                    obs_shape=(10, ),
                                                    obs_dtype=np.float32)

            spec = specs.make_environment_spec(environment)
            dataset_demonstration = fakes.transition_dataset(environment)
            dataset_demonstration = dataset_demonstration.map(
                lambda sample: types.Transition(*sample.data))
            dataset_demonstration = dataset_demonstration.batch(
                8).as_numpy_iterator()

            # Construct the agent.
            network = make_networks(spec, discrete_actions=True)

            def logp_fn(logits, actions):
                # Numerically stable log-softmax: shift by the max logit, then
                # take the log-probability of the demonstrated action.
                max_logits = jnp.max(logits, axis=-1, keepdims=True)
                logits = logits - max_logits
                logits_actions = jnp.sum(
                    jax.nn.one_hot(actions, spec.actions.num_values) * logits,
                    axis=-1)

                log_prob = logits_actions - special.logsumexp(logits, axis=-1)
                return log_prob

            if loss_name == 'logp':
                loss_fn = bc.logp(logp_fn=logp_fn)

            elif loss_name == 'rcal':
                base_loss_fn = bc.logp(logp_fn=logp_fn)
                loss_fn = bc.rcal(base_loss_fn, discount=0.99, alpha=0.1)

            else:
                raise ValueError(f'Unsupported loss name: {loss_name}')

            learner = bc.BCLearner(
                network=network,
                random_key=jax.random.PRNGKey(0),
                loss_fn=loss_fn,
                optimizer=optax.adam(0.01),
                demonstrations=dataset_demonstration,
                num_sgd_steps_per_step=num_sgd_steps_per_step)

            # Train the agent
            for _ in range(num_steps):
                learner.step()
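
The `make_networks` helper used above belongs to the surrounding test module rather than to the `bc` package. Below is a minimal sketch of what it might look like for the discrete-action case; the layer sizes and the use of `haiku` are assumptions, not the exact definition from the test file.

import haiku as hk
from acme import specs
from acme.jax import networks as networks_lib
from acme.jax import utils


def make_networks(spec: specs.EnvironmentSpec, discrete_actions: bool = False):
    """Hypothetical sketch: an MLP policy wrapped as a FeedForwardNetwork."""
    if not discrete_actions:
        raise NotImplementedError('Only the discrete case is sketched here.')
    num_actions = spec.actions.num_values

    def _actor_fn(obs):
        # One logit per discrete action; the logp_fn above turns these logits
        # into log-probabilities of the demonstrated actions.
        return hk.nets.MLP([64, 64, num_actions])(obs)

    policy = hk.without_apply_rng(hk.transform(_actor_fn))

    # Dummy batched observation used only to initialise the parameters.
    dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
    return networks_lib.FeedForwardNetwork(
        lambda key: policy.init(key, dummy_obs), policy.apply)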
Example 2
def main(_):
    # Create an environment and grab the spec.
    environment = bc_utils.make_environment()
    environment_spec = specs.make_environment_spec(environment)

    # Unwrap the environment and build the demonstration dataset.
    dataset = bc_utils.make_demonstrations(environment.environment,
                                           FLAGS.batch_size)
    dataset = dataset.as_numpy_iterator()

    # Create the networks to optimize.
    network = bc_utils.make_network(environment_spec)

    key = jax.random.PRNGKey(FLAGS.seed)
    key, key1 = jax.random.split(key, 2)

    def logp_fn(logits, actions):
        logits_actions = jnp.sum(jax.nn.one_hot(actions, logits.shape[-1]) *
                                 logits,
                                 axis=-1)
        logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
        return logits_actions

    loss_fn = bc.logp(logp_fn=logp_fn)

    learner = bc.BCLearner(network=network,
                           random_key=key1,
                           loss_fn=loss_fn,
                           optimizer=optax.adam(FLAGS.learning_rate),
                           demonstrations=dataset,
                           num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.ndarray,
                          observation: jnp.ndarray) -> jnp.ndarray:
        dist_params = network.apply(params, observation)
        return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
            key, dist_params)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
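
The script relies on several absl flags (`batch_size`, `seed`, `learning_rate`, `evaluation_epsilon`, `evaluate_every`, `evaluation_episodes`) that are defined elsewhere in the file. A plausible set of definitions is sketched below; the default values are illustrative, not the ones used in the original example.

from absl import app
from absl import flags

flags.DEFINE_integer('batch_size', 64, 'Batch size used by the BC learner.')
flags.DEFINE_integer('seed', 0, 'Random seed for networks and evaluation.')
flags.DEFINE_float('learning_rate', 1e-4, 'Adam learning rate.')
flags.DEFINE_float('evaluation_epsilon', 0.,
                   'Epsilon for the epsilon-greedy evaluator policy.')
flags.DEFINE_integer('evaluate_every', 100,
                     'Learner steps between evaluation runs.')
flags.DEFINE_integer('evaluation_episodes', 10, 'Episodes per evaluation run.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)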
Example 3
    def test_continuous_actions(self, loss_name):
        with chex.fake_pmap_and_jit():
            num_sgd_steps_per_step = 1
            num_steps = 5

            # Create a fake environment to test with.
            environment = fakes.ContinuousEnvironment(episode_length=10,
                                                      bounded=True,
                                                      action_dim=6)

            spec = specs.make_environment_spec(environment)
            dataset_demonstration = fakes.transition_dataset(environment)
            dataset_demonstration = dataset_demonstration.map(
                lambda sample: types.Transition(*sample.data))
            dataset_demonstration = dataset_demonstration.batch(
                8).as_numpy_iterator()

            # Construct the agent.
            network = make_networks(spec)

            if loss_name == 'logp':
                loss_fn = bc.logp(
                    logp_fn=lambda dist_params, actions: dist_params.log_prob(actions))
            elif loss_name == 'mse':
                loss_fn = bc.mse(
                    sample_fn=lambda dist_params, key: dist_params.sample(seed=key))
            elif loss_name == 'peerbc':
                base_loss_fn = bc.logp(
                    logp_fn=lambda dist_params, actions: dist_params.log_prob(actions))
                loss_fn = bc.peerbc(base_loss_fn, zeta=0.1)
            else:
                raise ValueError(f'Unsupported loss name: {loss_name}')

            learner = bc.BCLearner(
                network=network,
                random_key=jax.random.PRNGKey(0),
                loss_fn=loss_fn,
                optimizer=optax.adam(0.01),
                demonstrations=dataset_demonstration,
                num_sgd_steps_per_step=num_sgd_steps_per_step)

            # Train the agent
            for _ in range(num_steps):
                learner.step()
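
For these losses the network must return a distribution, since `dist_params.log_prob(actions)` and `dist_params.sample(seed=key)` are called on its output. A minimal sketch of a matching `make_networks` for the continuous case follows; the MLP sizes and the use of `NormalTanhDistribution` are assumptions rather than the test file's actual definition.

import haiku as hk
import numpy as np
from acme import specs
from acme.jax import networks as networks_lib
from acme.jax import utils


def make_networks(spec: specs.EnvironmentSpec):
    """Hypothetical sketch: a policy that outputs a tanh-squashed Normal."""
    num_dimensions = np.prod(spec.actions.shape, dtype=int)

    def _actor_fn(obs):
        network = hk.Sequential([
            hk.nets.MLP([64, 64], activate_final=True),
            # Distribution head, so log_prob / sample are available to the losses.
            networks_lib.NormalTanhDistribution(num_dimensions),
        ])
        return network(obs)

    policy = hk.without_apply_rng(hk.transform(_actor_fn))
    dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
    return networks_lib.FeedForwardNetwork(
        lambda key: policy.init(key, dummy_obs), policy.apply)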
Example 4
def add_bc_pretraining(sac_networks: sac.SACNetworks) -> sac.SACNetworks:
    """Augments `sac_networks` to run BC pretraining in policy_network.init."""

    make_demonstrations = functools.partial(
        helpers.make_demonstration_iterator, dataset_name=FLAGS.dataset_name)
    bc_network = bc.pretraining.convert_to_bc_network(
        sac_networks.policy_network)
    loss = bc.logp(sac_networks.log_prob)

    def bc_init(*unused_args):
        return bc.pretraining.train_with_bc(make_demonstrations, bc_network,
                                            loss)

    return dataclasses.replace(sac_networks,
                               policy_network=networks_lib.FeedForwardNetwork(
                                   bc_init, sac_networks.policy_network.apply))
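
A hypothetical usage sketch, assuming an `environment_spec` built with `specs.make_environment_spec` is available in the caller: the SAC networks are created as usual and then augmented, so that the next call to `policy_network.init` runs BC pretraining instead of a plain parameter initialisation.

# Hypothetical usage; `environment_spec` is assumed to exist in the caller.
sac_networks = sac.make_networks(environment_spec)
sac_networks = add_bc_pretraining(sac_networks)
params = sac_networks.policy_network.init(jax.random.PRNGKey(0))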
Example 5
def add_bc_pretraining(ppo_networks: ppo.PPONetworks) -> ppo.PPONetworks:
    """Augments `ppo_networks` to run BC pretraining in policy_network.init."""

    make_demonstrations = functools.partial(
        helpers.make_demonstration_iterator, dataset_name=FLAGS.dataset_name)
    bc_network = bc.pretraining.convert_policy_value_to_bc_network(
        ppo_networks.network)
    loss = bc.logp(ppo_networks.log_prob)

    # Note: although only the policy head is trained, the returned parameters
    # also include the initial value network params.
    def bc_init(*unused_args):
        return bc.pretraining.train_with_bc(make_demonstrations, bc_network,
                                            loss)

    return dataclasses.replace(ppo_networks,
                               network=networks_lib.FeedForwardNetwork(
                                   bc_init, ppo_networks.network.apply))