コード例 #1
0
def main(_):
    """Trains a behavioural-cloning learner and evaluates it periodically.

    Builds the environment, its demonstration iterator and the policy network,
    creates a `bc.BCLearner` using the `bc.logp` loss, and then alternates
    forever between `FLAGS.evaluate_every` learner steps and
    `FLAGS.evaluation_episodes` evaluation episodes. This function never
    returns on its own.

    Args:
      _: Unused positional argument.
    """
    # Create an environment and grab the spec.
    environment = bc_utils.make_environment()
    environment_spec = specs.make_environment_spec(environment)

    # Unwrap the environment to get the demonstrations.
    dataset = bc_utils.make_demonstrations(environment.environment,
                                           FLAGS.batch_size)
    dataset = dataset.as_numpy_iterator()

    # Create the networks to optimize.
    network = bc_utils.make_network(environment_spec)

    # One key drives the evaluator's action sampling, the other the learner.
    evaluator_key, learner_key = jax.random.split(
        jax.random.PRNGKey(FLAGS.seed), 2)

    def logp_fn(logits, actions):
        """Returns log-softmax(logits) evaluated at the given actions."""
        # Select the logit of each taken action via a one-hot mask over the
        # last axis.
        logits_actions = jnp.sum(jax.nn.one_hot(actions, logits.shape[-1]) *
                                 logits,
                                 axis=-1)
        # Subtract the log-normaliser to turn the logit into a log-probability.
        logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
        return logits_actions

    loss_fn = bc.logp(logp_fn=logp_fn)

    learner = bc.BCLearner(network=network,
                           random_key=learner_key,
                           loss_fn=loss_fn,
                           optimizer=optax.adam(FLAGS.learning_rate),
                           demonstrations=dataset,
                           num_sgd_steps_per_step=1)

    # NOTE: annotations use `jnp.ndarray` rather than `jnp.DeviceArray`; the
    # latter is not available in recent JAX releases and, because annotations
    # are evaluated when this nested function is defined, would make `main`
    # fail with an AttributeError before any training happens.
    def evaluator_network(params: hk.Params, key: jnp.ndarray,
                          observation: jnp.ndarray) -> jnp.ndarray:
        """Samples an epsilon-greedy action from the network's output."""
        dist_params = network.apply(params, observation)
        return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
            key, dist_params)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    # The evaluator reads the 'policy' variables from the learner.
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    evaluator_key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop: train for a while, then evaluate, forever.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
コード例 #2
0
    def test_discrete_actions(self, loss_name):
        """Runs a few BC learner steps on a fake discrete-action environment.

        Args:
          loss_name: Loss to build; 'logp' or 'rcal'. Any other value raises
            ValueError.
        """
        with chex.fake_pmap_and_jit():

            sgd_steps_per_step = 1
            total_steps = 5

            # Fake environment the demonstrations are generated from.
            env = fakes.DiscreteEnvironment(num_actions=10,
                                            num_observations=100,
                                            obs_shape=(10, ),
                                            obs_dtype=np.float32)
            spec = specs.make_environment_spec(env)

            # Demonstrations as an iterator over batches of 8 transitions.
            demos = fakes.transition_dataset(env).map(
                lambda sample: types.Transition(*sample.data))
            demos = demos.batch(8).as_numpy_iterator()

            # Construct the agent.
            network = make_networks(spec, discrete_actions=True)

            def logp_fn(logits, actions):
                # Shift by the per-row maximum before normalising.
                shifted = logits - jnp.max(logits, axis=-1, keepdims=True)
                action_mask = jax.nn.one_hot(actions, spec.actions.num_values)
                picked = jnp.sum(action_mask * shifted, axis=-1)
                return picked - special.logsumexp(shifted, axis=-1)

            # Reject unknown loss names before building anything.
            if loss_name not in ('logp', 'rcal'):
                raise ValueError

            # Both supported losses start from the log-prob loss; 'rcal'
            # wraps it.
            loss_fn = bc.logp(logp_fn=logp_fn)
            if loss_name == 'rcal':
                loss_fn = bc.rcal(loss_fn, discount=0.99, alpha=0.1)

            learner = bc.BCLearner(
                network=network,
                random_key=jax.random.PRNGKey(0),
                loss_fn=loss_fn,
                optimizer=optax.adam(0.01),
                demonstrations=demos,
                num_sgd_steps_per_step=sgd_steps_per_step)

            # Train the agent.
            for _ in range(total_steps):
                learner.step()
コード例 #3
0
def make_ensemble_regressor_learner(
    name: str,
    num_networks: int,
    logger_fn: loggers.LoggerFactory,
    counter: counting.Counter,
    rng_key: jnp.ndarray,
    iterator: Iterator[types.Transition],
    base_network: networks_lib.FeedForwardNetwork,
    loss: mbop_losses.TransitionLoss,
    optimizer: optax.GradientTransformation,
    num_sgd_steps_per_step: int,
):
  """Builds a learner that trains an ensemble built from `base_network`.

  The ensemble is wrapped in a `bc.BCLearner` whose loss ignores the random
  key and evaluates `loss` on the parameter-bound apply function.

  Args:
    name: Label used as the counter prefix and passed to `logger_fn`.
    num_networks: How many networks the ensemble contains.
    logger_fn: Factory called with the name and the counter's steps key to
      build the logger; if falsy, no logger is created.
    counter: Parent of the counter created for this learner.
    rng_key: Random key handed to the learner.
    iterator: Iterator of transitions the learner trains on.
    base_network: Network replicated to form the ensemble.
    loss: Transition loss applied to the bound network and a batch.
    optimizer: Optax optimizer handed to the learner.
    num_sgd_steps_per_step: Number of gradient updates per learner step.

  Returns:
    The `bc.BCLearner` training the ensemble.
  """
  ensemble_network = ensemble.make_ensemble(
      base_network, ensemble.apply_all, num_networks)
  learner_counter = counting.Counter(parent=counter, prefix=name)

  if logger_fn:
    learner_logger = logger_fn(name, learner_counter.get_steps_key())
  else:
    learner_logger = None

  def loss_fn(apply_fn: Callable[..., networks_lib.NetworkOutput],
              params: networks_lib.Params, key: jnp.ndarray,
              transitions: types.Transition) -> jnp.ndarray:
    # The transition loss takes no random key.
    del key
    bound_apply = functools.partial(apply_fn, params)
    return loss(bound_apply, transitions)

  # This is effectively a regressor learner.
  return bc.BCLearner(
      ensemble_network,
      rng_key,
      loss_fn,
      optimizer,
      iterator,
      num_sgd_steps_per_step,
      logger=learner_logger,
      counter=learner_counter)
コード例 #4
0
    def test_continuous_actions(self, loss_name):
        """Runs a few BC learner steps on a fake continuous-action environment.

        Args:
          loss_name: Loss to build; 'logp', 'mse' or 'peerbc'. Any other value
            raises ValueError.
        """
        with chex.fake_pmap_and_jit():
            sgd_steps_per_step = 1
            total_steps = 5

            # Fake environment the demonstrations are generated from.
            env = fakes.ContinuousEnvironment(episode_length=10,
                                              bounded=True,
                                              action_dim=6)
            spec = specs.make_environment_spec(env)

            # Demonstrations as an iterator over batches of 8 transitions.
            demos = fakes.transition_dataset(env).map(
                lambda sample: types.Transition(*sample.data))
            demos = demos.batch(8).as_numpy_iterator()

            # Construct the agent.
            network = make_networks(spec)

            def log_prob(dist_params, actions):
                return dist_params.log_prob(actions)

            def draw_sample(dist_params, key):
                return dist_params.sample(seed=key)

            # Lazy builders so that only the requested loss is constructed.
            loss_builders = {
                'logp': lambda: bc.logp(logp_fn=log_prob),
                'mse': lambda: bc.mse(sample_fn=draw_sample),
                'peerbc': lambda: bc.peerbc(bc.logp(logp_fn=log_prob),
                                            zeta=0.1),
            }
            if loss_name not in loss_builders:
                raise ValueError
            loss_fn = loss_builders[loss_name]()

            learner = bc.BCLearner(
                network=network,
                random_key=jax.random.PRNGKey(0),
                loss_fn=loss_fn,
                optimizer=optax.adam(0.01),
                demonstrations=demos,
                num_sgd_steps_per_step=sgd_steps_per_step)

            # Train the agent.
            for _ in range(total_steps):
                learner.step()