Example #1
def main(_):
    # Create an environment and grab the spec.
    raw_environment = bsuite.load_and_record_to_csv(
        bsuite_id=FLAGS.bsuite_id,
        results_dir=FLAGS.results_dir,
        overwrite=FLAGS.overwrite,
    )
    environment = wrappers.SinglePrecisionWrapper(raw_environment)
    environment_spec = specs.make_environment_spec(environment)

    # Create the networks to optimize.
    network = make_network(environment_spec.actions)

    agent = impala.IMPALA(
        environment_spec=environment_spec,
        network=network,
        sequence_length=3,
        sequence_period=3,
    )

    # Run the environment loop.
    loop = acme.EnvironmentLoop(environment, agent)
    loop.run(num_episodes=environment.bsuite_num_episodes)  # pytype: disable=attribute-error
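The make_network helper is not shown in this snippet. A minimal sketch of what it could look like, assuming the Sonnet-based IMPALA agent, which expects an snt.RNNCore producing policy logits and a value estimate:

# Hedged sketch of the missing make_network helper (an assumption, not part of
# the original example): a recurrent torso followed by a policy/value head.
import sonnet as snt
from acme import specs
from acme.tf import networks


def make_network(action_spec: specs.DiscreteArray) -> snt.RNNCore:
    return snt.DeepRNN([
        snt.Flatten(),
        snt.LSTM(20),
        snt.nets.MLP([50, 50]),
        # PolicyValueHead outputs (logits, value) as IMPALA requires.
        networks.PolicyValueHead(action_spec.num_values),
    ])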
Example #2
    def test_ddpg(self):
        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10,
                                                  bounded=True)
        spec = specs.make_environment_spec(environment)

        # Create the networks to optimize (online) and target networks.
        agent_networks = make_networks(spec.actions)

        # Construct the agent.
        agent = ddpg.DDPG(
            environment_spec=spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            batch_size=10,
            samples_per_insert=2,
            min_replay_size=10,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=2)
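The make_networks helper is not included in the snippet above. A minimal sketch, assuming the Sonnet networks commonly paired with the TF DDPG agent (the layer sizes are placeholders):

# Hedged sketch of make_networks for the DDPG test above (an assumption):
# a tanh-bounded policy MLP and a critic multiplexer over (observation, action).
import numpy as np
import sonnet as snt
from acme.tf import networks


def make_networks(action_spec):
    num_dimensions = np.prod(action_spec.shape, dtype=int)
    policy_network = snt.Sequential([
        networks.LayerNormMLP([50, 50, num_dimensions]),
        networks.TanhToSpec(action_spec),
    ])
    critic_network = networks.CriticMultiplexer(
        critic_network=networks.LayerNormMLP([50, 50, 1]))
    return {'policy': policy_network, 'critic': critic_network}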
Example #3
def main(_):
    env = helpers.make_environment(level=FLAGS.level, oar_wrapper=True)
    env_spec = acme.make_environment_spec(env)

    config = impala.IMPALAConfig(
        batch_size=16,
        sequence_period=10,
        seed=FLAGS.seed,
    )

    networks = impala.make_atari_networks(env_spec)
    agent = impala.IMPALAFromConfig(
        environment_spec=env_spec,
        forward_fn=networks.forward_fn,
        unroll_init_fn=networks.unroll_init_fn,
        unroll_fn=networks.unroll_fn,
        initial_state_init_fn=networks.initial_state_init_fn,
        initial_state_fn=networks.initial_state_fn,
        config=config,
    )

    loop = acme.EnvironmentLoop(env, agent)
    loop.run(FLAGS.num_episodes)
Example #4
  def test_impala(self):
    # Create a fake environment to test with.
    environment = fakes.DiscreteEnvironment(
        num_actions=5,
        num_observations=10,
        obs_shape=(10, 5),
        obs_dtype=np.float32,
        episode_length=10)
    spec = specs.make_environment_spec(environment)

    def forward_fn(x, s):
      model = MyNetwork(spec.actions.num_values)
      return model(x, s)

    def initial_state_fn(batch_size: Optional[int] = None):
      model = MyNetwork(spec.actions.num_values)
      return model.initial_state(batch_size)

    def unroll_fn(inputs, state):
      model = MyNetwork(spec.actions.num_values)
      return hk.static_unroll(model, inputs, state)

    # Construct the agent.
    agent = impala.IMPALA(
        environment_spec=spec,
        forward_fn=forward_fn,
        initial_state_fn=initial_state_fn,
        unroll_fn=unroll_fn,
        sequence_length=3,
        sequence_period=3,
        batch_size=6,
    )

    # Try running the environment loop. We have no assertions here because all
    # we care about is that the agent runs without raising any errors.
    loop = acme.EnvironmentLoop(environment, agent)
    loop.run(num_episodes=20)
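MyNetwork is not defined in this snippet. A plausible sketch, assuming a Haiku recurrent core that returns (policy logits, value) plus the next recurrent state, which is the interface the forward_fn/unroll_fn/initial_state_fn above rely on:

# Hedged sketch of a MyNetwork definition compatible with the test above
# (an assumption; the real class is defined elsewhere in the test module).
from typing import Optional

import haiku as hk
import jax.numpy as jnp


class MyNetwork(hk.RNNCore):
  """Recurrent torso with policy-logit and value heads."""

  def __init__(self, num_actions: int):
    super().__init__(name='my_network')
    self._torso = hk.Sequential([hk.Flatten(), hk.nets.MLP([50, 50])])
    self._core = hk.LSTM(20)
    self._policy_head = hk.Linear(num_actions)
    self._value_head = hk.Linear(1)

  def __call__(self, inputs, state):
    embeddings = self._torso(inputs)
    embeddings, new_state = self._core(embeddings, state)
    logits = self._policy_head(embeddings)
    value = jnp.squeeze(self._value_head(embeddings), axis=-1)
    return (logits, value), new_state

  def initial_state(self, batch_size: Optional[int] = None):
    return self._core.initial_state(batch_size)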
Example #5
    def test_td3(self):
        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10,
                                                  action_dim=3,
                                                  observation_dim=5,
                                                  bounded=True)
        spec = specs.make_environment_spec(environment)

        # Create the networks.
        network = td3.make_networks(spec)

        config = td3.TD3Config(batch_size=10, min_replay_size=1)

        counter = counting.Counter()
        agent = td3.TD3(spec=spec,
                        network=network,
                        config=config,
                        seed=0,
                        counter=counter)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        loop.run(num_episodes=2)
Example #6
  def actor(
      self,
      replay: reverb.Client,
      variable_source: acme.VariableSource,
      counter: counting.Counter,
  ) -> acme.EnvironmentLoop:
    """The actor process."""

    # Create the behavior policy.
    networks = self._network_factory(self._environment_spec.actions)
    networks.init(self._environment_spec)
    policy_network = networks.make_policy(
        environment_spec=self._environment_spec,
        sigma=self._sigma,
    )

    # Create the agent.
    actor = self._builder.make_actor(
        policy_network=policy_network,
        adder=self._builder.make_adder(replay),
        variable_source=variable_source,
    )

    # Create the environment.
    environment = self._environment_factory(False)

    # Create logger and counter; actors will not spam bigtable.
    counter = counting.Counter(counter, 'actor')
    logger = loggers.make_default_logger(
        'actor',
        save_data=False,
        time_delta=self._log_every,
        steps_key='actor_steps')

    # Create the loop to connect environment and agent.
    return acme.EnvironmentLoop(environment, actor, counter, logger)
Example #7
def main(_):
    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=_ENV_NAME.value)
    spec = specs.make_environment_spec(environment)

    key = jax.random.PRNGKey(_SEED.value)
    key, dataset_key, evaluator_key = jax.random.split(key, 3)

    # Load the dataset.
    dataset = tensorflow_datasets.load(_DATASET_NAME.value)['train']
    # Unwrap the environment to get the demonstrations.
    dataset = mbop.episodes_to_timestep_batched_transitions(dataset,
                                                            return_horizon=10)
    dataset = tfds.JaxInMemoryRandomSampleIterator(
        dataset, key=dataset_key, batch_size=_BATCH_SIZE.value)

    # Apply normalization to the dataset.
    mean_std = mbop.get_normalization_stats(dataset,
                                            _NUM_NORMALIZATION_BATCHES.value)
    apply_normalization = jax.jit(
        functools.partial(running_statistics.normalize, mean_std=mean_std))
    dataset = (apply_normalization(sample) for sample in dataset)

    # Create the networks.
    networks = mbop.make_networks(spec,
                                  hidden_layer_sizes=tuple(
                                      _HIDDEN_LAYER_SIZES.value))

    # Use the default losses.
    losses = mbop.MBOPLosses()

    def logger_fn(label: str, steps_key: str):
        return loggers.make_default_logger(label, steps_key=steps_key)

    def make_learner(name, logger_fn, counter, rng_key, dataset, network,
                     loss):
        return mbop.make_ensemble_regressor_learner(
            name,
            _NUM_NETWORKS.value,
            logger_fn,
            counter,
            rng_key,
            dataset,
            network,
            loss,
            optax.adam(_LEARNING_RATE.value),
            _NUM_SGD_STEPS_PER_STEP.value,
        )

    learner = mbop.MBOPLearner(
        networks, losses, dataset, key, logger_fn,
        functools.partial(make_learner, 'world_model'),
        functools.partial(make_learner, 'policy_prior'),
        functools.partial(make_learner, 'n_step_return'))

    planning_config = mbop.MPPIConfig()

    assert planning_config.n_trajectories % _NUM_NETWORKS.value == 0, (
        'Number of trajectories must be a multiple of the number of networks.')

    actor_core = mbop.make_ensemble_actor_core(networks,
                                               planning_config,
                                               spec,
                                               mean_std,
                                               use_round_robin=False)
    evaluator = mbop.make_actor(actor_core, evaluator_key, learner)

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Train the agent.
    while True:
        for _ in range(_EVALUATE_EVERY.value):
            learner.step()
        eval_loop.run(_EVALUATION_EPISODES.value)
Example #8
def train_and_evaluate(distance_fn, rng):
    """Train a policy on the learned distance function and evaluate task success.

  Args:
    distance_fn: function mapping a (state, goal)-pair to a state embedding and
        a distance estimate used for policy learning.
    rng: random key used to initialize evaluation actor.
  """
    goal_image = load_goal_image(FLAGS.robot_data_path)
    logdir = FLAGS.logdir
    video_dir = paths.process_path(logdir, 'videos')
    print('Writing videos to', video_dir)
    counter = counting.Counter()
    eval_counter = counting.Counter(counter, prefix='eval', time_delta=0.0)
    # Include training episodes and steps and walltime in the first eval logs.
    counter.increment(episodes=0, steps=0, walltime=0)

    environment = make_environment(
        task=FLAGS.task,
        end_on_success=FLAGS.end_on_success,
        max_episode_steps=FLAGS.max_episode_steps,
        distance_fn=distance_fn,
        goal_image=goal_image,
        baseline_distance=FLAGS.baseline_distance,
        logdir=video_dir,
        counter=counter,
        record_every=FLAGS.record_episodes_frequency,
        num_episodes_to_record=FLAGS.num_episodes_to_record)
    environment_spec = specs.make_environment_spec(environment)
    print('Environment spec')
    print(environment_spec)
    agent_networks = sac.make_networks(environment_spec)

    config = sac.SACConfig(
        target_entropy=sac.target_entropy_from_env_spec(environment_spec),
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
        min_replay_size=FLAGS.min_replay_size)
    agent = deprecated_sac.SAC(environment_spec,
                               agent_networks,
                               config=config,
                               counter=counter,
                               seed=FLAGS.seed)

    env_logger = loggers.CSVLogger(logdir, 'env_loop', flush_every=5)
    eval_env_logger = loggers.CSVLogger(logdir, 'eval_env_loop', flush_every=1)
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      label='train_loop',
                                      logger=env_logger,
                                      counter=counter)

    eval_actor = agent.builder.make_actor(random_key=rng,
                                          policy=sac.apply_policy_and_sample(
                                              agent_networks, eval_mode=True),
                                          environment_spec=environment_spec,
                                          variable_source=agent)

    eval_video_dir = paths.process_path(logdir, 'eval_videos')
    print('Writing eval videos to', eval_video_dir)
    if FLAGS.baseline_distance_from_goal_to_goal:
        state = goal_image
        if distance_fn.history_length > 1:
            state = np.stack([goal_image] * distance_fn.history_length,
                             axis=-1)
        unused_embeddings, baseline_distance = distance_fn(state, goal_image)
        print('Baseline prediction', baseline_distance)
    else:
        baseline_distance = FLAGS.baseline_distance
    eval_env = make_environment(task=FLAGS.task,
                                end_on_success=False,
                                max_episode_steps=FLAGS.max_episode_steps,
                                distance_fn=distance_fn,
                                goal_image=goal_image,
                                eval_mode=True,
                                logdir=eval_video_dir,
                                counter=eval_counter,
                                record_every=FLAGS.num_eval_episodes,
                                num_episodes_to_record=FLAGS.num_eval_episodes,
                                baseline_distance=baseline_distance)

    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     label='eval_loop',
                                     logger=eval_env_logger,
                                     counter=eval_counter)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=FLAGS.num_eval_episodes)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=FLAGS.num_eval_episodes)
Example #9
def run_experiment(experiment: config.ExperimentConfig,
                   eval_every: int = 100,
                   num_eval_episodes: int = 1):
  """Runs a simple, single-threaded training loop using the default evaluators.

  It prioritizes simplicity of the code, so only the basic features of the
  ExperimentConfig are supported.

  Arguments:
    experiment: Definition and configuration of the agent to run.
    eval_every: After how many actor steps to perform evaluation.
    num_eval_episodes: How many evaluation episodes to execute at each
      evaluation step.
  """

  key = jax.random.PRNGKey(experiment.seed)

  # Create the environment and get its spec.
  environment = experiment.environment_factory(experiment.seed)
  environment_spec = experiment.environment_spec or specs.make_environment_spec(
      environment)

  # Create the networks and policy.
  networks = experiment.network_factory(environment_spec)
  policy = config.make_policy(
      experiment=experiment,
      networks=networks,
      environment_spec=environment_spec,
      evaluation=False)

  # Create the replay server and grab its address.
  replay_tables = experiment.builder.make_replay_tables(environment_spec,
                                                        policy)

  # Disable blocking of inserts by the tables' rate limiters, as this function
  # executes learning (sampling from the table) and data generation
  # (inserting into the table) sequentially from the same thread, so a blocked
  # insert could make the algorithm hang.
  replay_tables, rate_limiters_max_diff = _disable_insert_blocking(
      replay_tables)

  replay_server = reverb.Server(replay_tables, port=None)
  replay_client = reverb.Client(f'localhost:{replay_server.port}')

  # The parent counter shares step counts between the train and eval loops and
  # the learner, so that it is possible to plot, for example, the evaluator's
  # return as a function of the number of training episodes.
  parent_counter = counting.Counter(time_delta=0.)

  # Create actor, and learner for generating, storing, and consuming
  # data respectively.
  dataset = experiment.builder.make_dataset_iterator(replay_client)
  # We always use prefetch, as it provides an iterator with additional
  # 'ready' method.
  dataset = utils.prefetch(dataset, buffer_size=1)
  learner_key, key = jax.random.split(key)
  learner = experiment.builder.make_learner(
      random_key=learner_key,
      networks=networks,
      dataset=dataset,
      logger_fn=experiment.logger_factory,
      environment_spec=environment_spec,
      replay_client=replay_client,
      counter=counting.Counter(parent_counter, prefix='learner', time_delta=0.))

  adder = experiment.builder.make_adder(replay_client, environment_spec, policy)
  actor_key, key = jax.random.split(key)
  actor = experiment.builder.make_actor(
      actor_key, policy, environment_spec, variable_source=learner, adder=adder)

  # Create the environment loop used for training.
  train_counter = counting.Counter(
      parent_counter, prefix='train', time_delta=0.)
  train_logger = experiment.logger_factory('train',
                                           train_counter.get_steps_key(), 0)

  # Replace the actor with a LearningActor. This makes sure that every time
  # that `update` is called on the actor it checks to see whether there is
  # any new data to learn from and if so it runs a learner step. The rate
  # at which new data is released is controlled by the replay table's
  # rate_limiter which is created by the builder.make_replay_tables call above.
  actor = _LearningActor(actor, learner, dataset, replay_tables,
                         rate_limiters_max_diff)

  train_loop = acme.EnvironmentLoop(
      environment,
      actor,
      counter=train_counter,
      logger=train_logger,
      observers=experiment.observers)

  if num_eval_episodes == 0:
    # No evaluation. Just run the training loop.
    train_loop.run(num_steps=experiment.max_num_actor_steps)
    return

  # Create the evaluation actor and loop.
  eval_counter = counting.Counter(parent_counter, prefix='eval', time_delta=0.)
  eval_logger = experiment.logger_factory('eval', eval_counter.get_steps_key(),
                                          0)
  eval_policy = config.make_policy(
      experiment=experiment,
      networks=networks,
      environment_spec=environment_spec,
      evaluation=True)
  eval_actor = experiment.builder.make_actor(
      random_key=jax.random.PRNGKey(experiment.seed),
      policy=eval_policy,
      environment_spec=environment_spec,
      variable_source=learner)
  eval_loop = acme.EnvironmentLoop(
      environment,
      eval_actor,
      counter=eval_counter,
      logger=eval_logger,
      observers=experiment.observers)

  steps = 0
  while steps < experiment.max_num_actor_steps:
    eval_loop.run(num_episodes=num_eval_episodes)
    steps += train_loop.run(num_steps=eval_every)
  eval_loop.run(num_episodes=num_eval_episodes)
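A hypothetical usage sketch for run_experiment. The ExperimentConfig field names follow acme.jax.experiments and should be treated as assumptions, as should the helpers.make_environment import:

# Build an ExperimentConfig for a SAC agent and hand it to run_experiment.
from acme.agents.jax import sac
from acme.jax import experiments

import helpers  # assumed local module providing make_environment


def build_experiment_config() -> experiments.ExperimentConfig:
  def environment_factory(seed: int):
    del seed
    return helpers.make_environment(task='HalfCheetah-v2')

  return experiments.ExperimentConfig(
      builder=sac.SACBuilder(sac.SACConfig()),
      environment_factory=environment_factory,
      network_factory=sac.make_networks,
      seed=0,
      max_num_actor_steps=1_000_000)


run_experiment(build_experiment_config(), eval_every=10_000, num_eval_episodes=5)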
Example #10
def main(_):
    # TODO(yutian): Create environment.
    # # Create an environment and grab the spec.
    # raw_environment = bsuite.load_and_record_to_csv(
    #     bsuite_id=FLAGS.bsuite_id,
    #     results_dir=FLAGS.results_dir,
    #     overwrite=FLAGS.overwrite,
    # )
    # environment = single_precision.SinglePrecisionWrapper(raw_environment)
    # environment_spec = specs.make_environment_spec(environment)

    # TODO(yutian): Create dataset.
    # Build the dataset.
    # if hasattr(raw_environment, 'raw_env'):
    #   raw_environment = raw_environment.raw_env
    #
    # batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)
    # # Combine with demonstration dataset.
    # transition = functools.partial(
    #     _n_step_transition_from_episode, n_step=1, additional_discount=1.)
    #
    # dataset = batch_dataset.map(transition)
    #
    # # Batch and prefetch.
    # dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    # dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # Create the networks to optimize.
    networks = make_networks(environment_spec.actions)
    treatment_net = networks['treatment_net']
    instrumental_net = networks['instrumental_net']
    policy_net = networks['policy_net']

    # If the agent is non-autoregressive, use epsilon=0, which yields a greedy
    # policy.
    evaluator_net = snt.Sequential([
        policy_net,
        # Sample actions.
        acme_nets.StochasticSamplingHead()
    ])

    # Ensure that we create the variables before proceeding (maybe not needed).
    tf2_utils.create_variables(policy_net, [environment_spec.observations])
    # TODO(liyuan): set the proper input spec using environment_spec.observations
    # and environment_spec.actions.
    tf2_utils.create_variables(treatment_net, [environment_spec.observations])
    tf2_utils.create_variables(
        instrumental_net,
        [environment_spec.observations, environment_spec.actions])

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # Create the actor which defines how we take actions.
    evaluator = actors.FeedForwardActor(evaluator_net)

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     counter=counter,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=1.))

    # The learner updates the parameters (and initializes them).
    learner = learning.DFIVLearner(
        treatment_net=treatment_net,
        instrumental_net=instrumental_net,
        policy_net=policy_net,
        treatment_learning_rate=FLAGS.treatment_learning_rate,
        instrumental_learning_rate=FLAGS.instrumental_learning_rate,
        policy_learning_rate=FLAGS.policy_learning_rate,
        dataset=dataset,
        counter=learner_counter)

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
Example #11
def main(_):
    # Create an environment, grab the spec.
    environment = utils.make_environment(task=FLAGS.env_name)
    aqua_config = config.AquademConfig()
    spec = specs.make_environment_spec(environment)
    discretized_spec = aquadem_builder.discretize_spec(spec,
                                                       aqua_config.num_actions)

    # Create AQuaDem builder.
    loss_fn = dqn.losses.MunchausenQLearning(max_abs_reward=100.)
    dqn_config = dqn.DQNConfig(min_replay_size=1000,
                               n_step=3,
                               num_sgd_steps_per_step=8,
                               learning_rate=1e-4,
                               samples_per_insert=256)
    rl_agent = dqn.DQNBuilder(config=dqn_config, loss_fn=loss_fn)
    make_demonstrations = utils.get_make_demonstrations_fn(
        FLAGS.env_name, FLAGS.num_demonstrations, FLAGS.seed)
    builder = aquadem_builder.AquademBuilder(
        rl_agent=rl_agent,
        config=aqua_config,
        make_demonstrations=make_demonstrations)

    # Create networks.
    q_network = aquadem_networks.make_q_network(spec=discretized_spec, )
    dqn_networks = dqn.DQNNetworks(
        policy_network=networks_lib.non_stochastic_network_to_typed(q_network))
    networks = aquadem_networks.make_action_candidates_network(
        spec=spec,
        num_actions=aqua_config.num_actions,
        discrete_rl_networks=dqn_networks)
    exploration_epsilon = 0.01
    discrete_policy = dqn.default_behavior_policy(dqn_networks,
                                                  exploration_epsilon)
    behavior_policy = aquadem_builder.get_aquadem_policy(
        discrete_policy, networks)

    # Create the environment loop used for training.
    agent = local_layout.LocalLayout(seed=FLAGS.seed,
                                     environment_spec=spec,
                                     builder=builder,
                                     networks=networks,
                                     policy_network=behavior_policy,
                                     batch_size=dqn_config.batch_size *
                                     dqn_config.num_sgd_steps_per_step)

    train_logger = loggers.CSVLogger(FLAGS.workdir, label='train')
    train_loop = acme.EnvironmentLoop(environment, agent, logger=train_logger)

    # Create the evaluation actor and loop.
    eval_policy = dqn.default_behavior_policy(dqn_networks, 0.)
    eval_policy = aquadem_builder.get_aquadem_policy(eval_policy, networks)
    eval_actor = builder.make_actor(random_key=jax.random.PRNGKey(FLAGS.seed),
                                    policy=eval_policy,
                                    environment_spec=spec,
                                    variable_source=agent)
    eval_env = utils.make_environment(task=FLAGS.env_name, evaluation=True)

    eval_logger = loggers.CSVLogger(FLAGS.workdir, label='eval')
    eval_loop = acme.EnvironmentLoop(eval_env, eval_actor, logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=10)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=10)
Example #12
    def test_ail_flax(self):
        shutil.rmtree(flags.FLAGS.test_tmpdir)
        batch_size = 8
        # Mujoco environment and associated demonstration dataset.
        environment = fakes.ContinuousEnvironment(
            episode_length=EPISODE_LENGTH,
            action_dim=CONTINUOUS_ACTION_DIM,
            observation_dim=CONTINUOUS_OBS_DIM,
            bounded=True)
        spec = specs.make_environment_spec(environment)

        networks = sac.make_networks(spec=spec)
        config = sac.SACConfig(batch_size=batch_size,
                               samples_per_insert_tolerance_rate=float('inf'),
                               min_replay_size=1)
        base_builder = sac.SACBuilder(config=config)
        direct_rl_batch_size = batch_size
        behavior_policy = sac.apply_policy_and_sample(networks)

        discriminator_module = DiscriminatorModule(spec, linen.Dense(1))

        def apply_fn(params: networks_lib.Params,
                     policy_params: networks_lib.Params,
                     state: networks_lib.Params, transitions: types.Transition,
                     is_training: bool,
                     rng: networks_lib.PRNGKey) -> networks_lib.Logits:
            del policy_params
            variables = dict(params=params, **state)
            return discriminator_module.apply(variables,
                                              transitions.observation,
                                              transitions.action,
                                              transitions.next_observation,
                                              is_training=is_training,
                                              rng=rng,
                                              mutable=state.keys())

        def init_fn(rng):
            variables = discriminator_module.init(rng,
                                                  dummy_obs,
                                                  dummy_actions,
                                                  dummy_obs,
                                                  is_training=False,
                                                  rng=rng)
            init_state, discriminator_params = variables.pop('params')
            return discriminator_params, init_state

        dummy_obs = utils.zeros_like(spec.observations)
        dummy_obs = utils.add_batch_dim(dummy_obs)
        dummy_actions = utils.zeros_like(spec.actions)
        dummy_actions = utils.add_batch_dim(dummy_actions)
        discriminator_network = networks_lib.FeedForwardNetwork(init=init_fn,
                                                                apply=apply_fn)

        networks = ail.AILNetworks(discriminator_network, lambda x: x,
                                   networks)

        builder = ail.AILBuilder(
            base_builder,
            config=ail.AILConfig(is_sequence_based=False,
                                 share_iterator=True,
                                 direct_rl_batch_size=direct_rl_batch_size,
                                 discriminator_batch_size=2,
                                 policy_variable_name=None,
                                 min_replay_size=1),
            discriminator_loss=ail.losses.gail_loss(),
            make_demonstrations=fakes.transition_iterator(environment))

        counter = counting.Counter()
        # Construct the agent.
        agent = local_layout.LocalLayout(
            seed=0,
            environment_spec=spec,
            builder=builder,
            networks=networks,
            policy_network=behavior_policy,
            min_replay_size=1,
            batch_size=batch_size,
            counter=counter,
        )

        # Train the agent.
        train_loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        train_loop.run(num_episodes=1)
Example #13
def main(_):
    # Create an environment and grab the spec.
    raw_environment = bsuite.load_and_record_to_csv(
        bsuite_id=FLAGS.bsuite_id,
        results_dir=FLAGS.results_dir,
        overwrite=FLAGS.overwrite,
    )
    environment = single_precision.SinglePrecisionWrapper(raw_environment)
    environment_spec = specs.make_environment_spec(environment)

    # Build demonstration dataset.
    if hasattr(raw_environment, 'raw_env'):
        raw_environment = raw_environment.raw_env

    batch_dataset = bsuite_demonstrations.make_dataset(raw_environment,
                                                       stochastic=False)
    # Combine with demonstration dataset.
    transition = functools.partial(_n_step_transition_from_episode,
                                   n_step=1,
                                   additional_discount=1.)

    dataset = batch_dataset.map(transition)

    # Batch and prefetch.
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # Create the networks to optimize.
    policy_network = make_policy_network(environment_spec.actions)

    # If the agent is non-autoregressive, use epsilon=0, which yields a greedy
    # policy.
    evaluator_network = snt.Sequential([
        policy_network,
        lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
    ])

    # Ensure that we create the variables before proceeding (maybe not needed).
    tf2_utils.create_variables(policy_network, [environment_spec.observations])

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # Create the actor which defines how we take actions.
    evaluation_network = actors.FeedForwardActor(evaluator_network)

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluation_network,
                                     counter=counter,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=1.))

    # The learner updates the parameters (and initializes them).
    learner = learning.BCLearner(network=policy_network,
                                 learning_rate=FLAGS.learning_rate,
                                 dataset=dataset,
                                 counter=learner_counter)

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
Example #14
def pong_experiment():
    seeds = [0, 42, 69, 360, 420]
    seed = seeds[0]
    start_time = time.strftime("%Y-%m-%d_%H-%M-%S")

    # setting torch random seed here to make sure random initialization is the same
    torch.random.manual_seed(seed)

    # creating the environment
    env_name = "PongNoFrameskip-v4"

    env_train = make_environment_atari(env_name, seed)
    env_test = make_environment_atari(env_name, seed)

    env_train_spec = acme.make_environment_spec(env_train)

    # creating the neural network
    network = PositionNetworkSingleHead(
        env_train_spec.observations[0].shape,
        env_train_spec.observations[1].shape[0] *
        env_train_spec.actions[1].num_values,
        env_train_spec.actions[0].num_values *
        env_train_spec.actions[1].num_values)

    # creating the logger
    training_logger = TensorBoardLogger("runs/DQN-train-" + env_name +
                                        f"-rnd{seed}-" + start_time)
    testing_logger = TensorBoardLogger("runs/DQN-test-" + env_name +
                                       f"-rnd{seed}-" + start_time)

    # creating the agent
    agent = VanillaPartialDQN(network, [
        env_train_spec.actions[0].num_values,
        env_train_spec.actions[1].num_values
    ],
                              training_logger,
                              gradient_clipping=True,
                              device='gpu',
                              seed=seed,
                              replay_start_size=100)

    training_loop = acme.EnvironmentLoop(env_train,
                                         agent,
                                         logger=training_logger)
    testing_loop = acme.EnvironmentLoop(env_test,
                                        agent,
                                        logger=testing_logger,
                                        should_update=False)

    for epoch in range(200):
        agent.training()
        training_loop.run(num_steps=250000)

        torch.save(
            network.state_dict(), "runs/DQN-train-" + env_name +
            f"-rnd{seed}-" + start_time + f"/ep{epoch}.model")

        agent.testing()
        testing_loop.run(num_episodes=30)

    training_logger.close()
    testing_logger.close()
Example #15
def main(_):
    key = jax.random.PRNGKey(FLAGS.seed)
    key_demonstrations, key_learner = jax.random.split(key, 2)

    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Get a demonstrations dataset with next_actions extra.
    transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                        FLAGS.num_demonstrations)
    double_transitions = rlds.transformations.batch(transitions,
                                                    size=2,
                                                    shift=1,
                                                    drop_remainder=True)
    transitions = double_transitions.map(_add_next_action_extras)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

    # Create the networks to optimize.
    networks = crr.make_networks(environment_spec)

    # CRR policy loss function.
    policy_loss_coeff_fn = crr.policy_loss_coeff_advantage_exp

    # Create the learner.
    learner = crr.CRRLearner(
        networks=networks,
        random_key=key_learner,
        discount=FLAGS.discount,
        target_update_period=FLAGS.target_update_period,
        policy_loss_coeff_fn=policy_loss_coeff_fn,
        iterator=demonstrations,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        grad_updates_per_batch=FLAGS.grad_updates_per_batch,
        use_sarsa_target=FLAGS.use_sarsa_target)

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        dist_params = networks.policy_network.apply(params, observation)
        return networks.sample_eval(dist_params, key)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
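The _add_next_action_extras helper is not shown above. A hedged sketch, assuming it turns each pair of consecutive transitions produced by rlds.transformations.batch into a single transition whose extras carry the next action (used for the SARSA-style target in CRR):

# Sketch of the assumed helper: keep the first transition of each pair and
# stash the second transition's action as 'next_action' in the extras.
from acme import types


def _add_next_action_extras(double_transitions: types.Transition) -> types.Transition:
    return types.Transition(
        observation=double_transitions.observation[0],
        action=double_transitions.action[0],
        reward=double_transitions.reward[0],
        discount=double_transitions.discount[0],
        next_observation=double_transitions.next_observation[0],
        extras={'next_action': double_transitions.action[1]})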
Example #16
  def actor(
      self,
      replay: reverb.Client,
      variable_source: acme.VariableSource,
      counter: counting.Counter,
      actor_id: int,
  ) -> acme.EnvironmentLoop:
    """The actor process."""

    action_spec = self._environment_spec.actions
    observation_spec = self._environment_spec.observations

    # Create environment and target networks to act with.
    environment = self._environment_factory(False)
    agent_networks = self._network_factory(action_spec)

    # Make sure observation network is defined.
    observation_network = agent_networks.get('observation', tf.identity)

    # Create a stochastic behavior policy.
    behavior_network = snt.Sequential([
        observation_network,
        agent_networks['policy'],
        networks.StochasticSamplingHead(),
    ])

    # Ensure network variables are created.
    tf2_utils.create_variables(behavior_network, [observation_spec])
    policy_variables = {'policy': behavior_network.variables}

    # Create the variable client responsible for keeping the actor up-to-date.
    variable_client = tf2_variable_utils.VariableClient(
        variable_source,
        policy_variables,
        update_period=self._variable_update_period)

    # Make sure not to use a random policy after checkpoint restoration by
    # assigning variables before running the environment loop.
    variable_client.update_and_wait()

    # Component to add things into replay.
    adder = adders.NStepTransitionAdder(
        client=replay,
        n_step=self._n_step,
        discount=self._additional_discount)

    # Create the agent.
    actor = actors.FeedForwardActor(
        policy_network=behavior_network,
        adder=adder,
        variable_client=variable_client)

    # Create logger and counter; only the first actor stores logs to bigtable.
    save_data = actor_id == 0
    counter = counting.Counter(counter, 'actor')
    logger = loggers.make_default_logger(
        'actor',
        save_data=save_data,
        time_delta=self._log_every,
        steps_key='actor_steps')
    observers = self._make_observers() if self._make_observers else ()

    # Create the run loop and return it.
    return acme.EnvironmentLoop(
        environment, actor, counter, logger, observers=observers)
Example #17
def main(_):
    # Create an environment and grab the spec.
    raw_environment = bsuite.load_and_record_to_csv(
        bsuite_id=FLAGS.bsuite_id,
        results_dir=FLAGS.results_dir,
        overwrite=FLAGS.overwrite,
    )
    environment = single_precision.SinglePrecisionWrapper(raw_environment)
    environment_spec = specs.make_environment_spec(environment)

    # Build demonstration dataset.
    if hasattr(raw_environment, 'raw_env'):
        raw_environment = raw_environment.raw_env

    batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)
    # Combine with demonstration dataset.
    transition = functools.partial(_n_step_transition_from_episode,
                                   n_step=1,
                                   additional_discount=1.)

    dataset = batch_dataset.map(transition)

    # Batch and prefetch.
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    dataset = tfds.as_numpy(dataset)

    # Create the networks to optimize.
    policy_network = make_policy_network(environment_spec.actions)
    policy_network = hk.without_apply_rng(hk.transform(policy_network))

    # If the agent is non-autoregressive, use epsilon=0, which yields a greedy
    # policy.
    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        action_values = policy_network.apply(params, observation)
        return rlax.epsilon_greedy(FLAGS.epsilon).sample(key, action_values)

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # The learner updates the parameters (and initializes them).
    learner = learning.BCLearner(network=policy_network,
                                 optimizer=optax.adam(FLAGS.learning_rate),
                                 obs_spec=environment.observation_spec(),
                                 dataset=dataset,
                                 counter=learner_counter,
                                 rng=hk.PRNGSequence(FLAGS.seed))

    # Create the actor which defines how we take actions.
    variable_client = variable_utils.VariableClient(learner, '')
    evaluator = actors.FeedForwardActor(evaluator_network,
                                        variable_client=variable_client,
                                        rng=hk.PRNGSequence(FLAGS.seed))

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     counter=counter,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=1.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
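The make_policy_network helper is not shown. A minimal sketch, assuming a Haiku MLP that maps observations to per-action values for the discrete bsuite action space:

# Hedged sketch of make_policy_network (an assumption, not the original helper).
import haiku as hk


def make_policy_network(action_spec):
    def network(obs):
        return hk.Sequential([
            hk.Flatten(),
            hk.nets.MLP([64, 64, action_spec.num_values]),
        ])(obs)

    return network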
Example #18
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = ppo.make_continuous_networks(environment_spec)

    # Construct the agent.
    ppo_config = ppo.PPOConfig(unroll_length=FLAGS.unroll_length,
                               num_minibatches=FLAGS.ppo_num_minibatches,
                               num_epochs=FLAGS.ppo_num_epochs,
                               batch_size=FLAGS.transition_batch_size //
                               FLAGS.unroll_length,
                               learning_rate=0.0003,
                               entropy_cost=0,
                               gae_lambda=0.8,
                               value_cost=0.25)
    ppo_networks = ppo.make_continuous_networks(environment_spec)
    if FLAGS.pretrain:
        ppo_networks = add_bc_pretraining(ppo_networks)

    discriminator_batch_size = FLAGS.transition_batch_size
    ail_config = ail.AILConfig(
        direct_rl_batch_size=ppo_config.batch_size * ppo_config.unroll_length,
        discriminator_batch_size=discriminator_batch_size,
        is_sequence_based=True,
        num_sgd_steps_per_step=FLAGS.num_discriminator_steps_per_step,
        share_iterator=FLAGS.share_iterator,
    )

    def discriminator(*args, **kwargs) -> networks_lib.Logits:
        # Note: observation embedding is not needed for e.g. Mujoco.
        return ail.DiscriminatorModule(
            environment_spec=environment_spec,
            use_action=True,
            use_next_obs=True,
            network_core=ail.DiscriminatorMLP([4, 4], ),
        )(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator))

    ail_network = ail.AILNetworks(
        ail.make_discriminator(environment_spec, discriminator_transformed),
        imitation_reward_fn=ail.rewards.gail_reward(),
        direct_rl_networks=ppo_networks)

    agent = ail.GAIL(spec=environment_spec,
                     network=ail_network,
                     config=ail.GAILConfig(ail_config, ppo_config),
                     seed=FLAGS.seed,
                     batch_size=ppo_config.batch_size,
                     make_demonstrations=functools.partial(
                         helpers.make_demonstration_iterator,
                         dataset_name=FLAGS.dataset_name),
                     policy_network=ppo.make_inference_fn(ppo_networks))

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=ppo.make_inference_fn(agent_networks, evaluation=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
Example #19
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Construct the agent.
    # The local layout makes sure that we populate the buffer with min_replay_size
    # initial transitions and that there is no need for a tolerance rate. To avoid
    # deadlocks, we need to disable the rate limiting that happens inside the
    # TD3Builder; this is achieved by the min_replay_size and
    # samples_per_insert_tolerance_rate arguments.
    td3_config = td3.TD3Config(
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
        min_replay_size=1,
        samples_per_insert_tolerance_rate=float('inf'))
    td3_networks = td3.make_networks(environment_spec)
    if FLAGS.pretrain:
        td3_networks = add_bc_pretraining(td3_networks)

    ail_config = ail.AILConfig(direct_rl_batch_size=td3_config.batch_size *
                               td3_config.num_sgd_steps_per_step)
    dac_config = ail.DACConfig(ail_config, td3_config)

    def discriminator(*args, **kwargs) -> networks_lib.Logits:
        return ail.DiscriminatorModule(environment_spec=environment_spec,
                                       use_action=True,
                                       use_next_obs=True,
                                       network_core=ail.DiscriminatorMLP(
                                           [4, 4], ))(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator))

    ail_network = ail.AILNetworks(
        ail.make_discriminator(environment_spec, discriminator_transformed),
        imitation_reward_fn=ail.rewards.gail_reward(),
        direct_rl_networks=td3_networks)

    agent = ail.DAC(spec=environment_spec,
                    network=ail_network,
                    config=dac_config,
                    seed=FLAGS.seed,
                    batch_size=td3_config.batch_size *
                    td3_config.num_sgd_steps_per_step,
                    make_demonstrations=functools.partial(
                        helpers.make_demonstration_iterator,
                        dataset_name=FLAGS.dataset_name),
                    policy_network=td3.get_default_behavior_policy(
                        td3_networks,
                        action_specs=environment_spec.actions,
                        sigma=td3_config.sigma))

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    # TODO(lukstafi): sigma=0 for eval?
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=td3.get_default_behavior_policy(
            td3_networks, action_specs=environment_spec.actions, sigma=0.),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
Example #20
def run_offline_experiment(experiment: config.OfflineExperimentConfig,
                           eval_every: int = 100,
                           num_eval_episodes: int = 1):
    """Runs a simple, single-threaded training loop using the default evaluators.

  It targets simplicity of the code and so only the basic features of the
  OfflineExperimentConfig are supported.

  Arguments:
    experiment: Definition and configuration of the agent to run.
    eval_every: After how many learner steps to perform evaluation.
    num_eval_episodes: How many evaluation episodes to execute at each
      evaluation step.
  """

    key = jax.random.PRNGKey(experiment.seed)

    # Create the environment and get its spec.
    environment = experiment.environment_factory(experiment.seed)
    environment_spec = experiment.environment_spec or specs.make_environment_spec(
        environment)

    # Create the networks and policy.
    networks = experiment.network_factory(environment_spec)

    # The parent counter shares step counts between the train and eval loops and
    # the learner, so that it is possible to plot, for example, the evaluator's
    # return as a function of the number of training episodes.
    parent_counter = counting.Counter(time_delta=0.)

    # Create the demonstrations dataset.
    dataset_key, key = jax.random.split(key)
    dataset = experiment.demonstration_dataset_factory(dataset_key)

    # Create the learner.
    learner_key, key = jax.random.split(key)
    learner = experiment.builder.make_learner(
        random_key=learner_key,
        networks=networks,
        dataset=dataset,
        logger_fn=experiment.logger_factory,
        environment_spec=environment_spec,
        counter=counting.Counter(parent_counter,
                                 prefix='learner',
                                 time_delta=0.))

    # Define the evaluation loop.
    eval_loop = None
    if num_eval_episodes > 0:
        # Create the evaluation actor and loop.
        eval_logger = experiment.logger_factory('eval', 'eval_steps', 0)
        eval_key, key = jax.random.split(key)
        eval_actor = experiment.builder.make_actor(
            random_key=eval_key,
            policy=experiment.builder.make_policy(networks, environment_spec,
                                                  True),
            environment_spec=environment_spec,
            variable_source=learner)
        eval_loop = acme.EnvironmentLoop(environment,
                                         eval_actor,
                                         counter=counting.Counter(
                                             parent_counter,
                                             prefix='eval',
                                             time_delta=0.),
                                         logger=eval_logger,
                                         observers=experiment.observers)

    # Run the training loop.
    if eval_loop:
        eval_loop.run(num_eval_episodes)
    steps = 0
    while steps < experiment.max_num_learner_steps:
        learner_steps = min(eval_every,
                            experiment.max_num_learner_steps - steps)
        for _ in range(learner_steps):
            learner.step()
        if eval_loop:
            eval_loop.run(num_eval_episodes)
        steps += learner_steps
Example #21
import gym

import acme
from acme import specs
from acme import wrappers
from acme.utils import loggers
from acme.wrappers import gym_wrapper

from agents.dqn_agent import DQNAgent
from networks.models import Models

from tensorflow.python.client import device_lib

print(device_lib.list_local_devices())


def render(env):
    return env.environment.render(mode='rgb_array')


environment = gym_wrapper.GymWrapper(gym.make('LunarLander-v2'))
environment = wrappers.SinglePrecisionWrapper(environment)
environment_spec = specs.make_environment_spec(environment)

model = Models.sequential_model(
    input_shape=environment_spec.observations.shape,
    num_outputs=environment_spec.actions.num_values,
    hidden_layers=3,
    layer_size=300)

agent = DQNAgent(environment_spec=environment_spec, network=model)

logger = loggers.TerminalLogger(time_delta=10.)
loop = acme.EnvironmentLoop(environment=environment, actor=agent, logger=logger)
loop.run()
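Models.sequential_model and DQNAgent come from this project's own modules and are not shown. A purely illustrative guess at the model helper, assuming a plain Keras MLP; the real implementation in networks/models.py may differ:

# Illustrative sketch of Models.sequential_model (hypothetical, not the repo code).
import tensorflow as tf


class Models:
    @staticmethod
    def sequential_model(input_shape, num_outputs, hidden_layers=3, layer_size=300):
        layers = [tf.keras.layers.InputLayer(input_shape=input_shape)]
        for _ in range(hidden_layers):
            layers.append(tf.keras.layers.Dense(layer_size, activation='relu'))
        layers.append(tf.keras.layers.Dense(num_outputs))
        return tf.keras.Sequential(layers)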
Example #22
    def test_ail(self,
                 algo,
                 airl_discriminator=False,
                 subtract_logpi=False,
                 dropout=0.,
                 lipschitz_coeff=None):
        shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
        batch_size = 8
        # Mujoco environment and associated demonstration dataset.
        if algo == 'ppo':
            environment = fakes.DiscreteEnvironment(
                num_actions=NUM_DISCRETE_ACTIONS,
                num_observations=NUM_OBSERVATIONS,
                obs_shape=OBS_SHAPE,
                obs_dtype=OBS_DTYPE,
                episode_length=EPISODE_LENGTH)
        else:
            environment = fakes.ContinuousEnvironment(
                episode_length=EPISODE_LENGTH,
                action_dim=CONTINUOUS_ACTION_DIM,
                observation_dim=CONTINUOUS_OBS_DIM,
                bounded=True)
        spec = specs.make_environment_spec(environment)

        if algo == 'sac':
            networks = sac.make_networks(spec=spec)
            config = sac.SACConfig(
                batch_size=batch_size,
                samples_per_insert_tolerance_rate=float('inf'),
                min_replay_size=1)
            base_builder = sac.SACBuilder(config=config)
            direct_rl_batch_size = batch_size
            behavior_policy = sac.apply_policy_and_sample(networks)
        elif algo == 'ppo':
            unroll_length = 5
            distribution_value_networks = make_ppo_networks(spec)
            networks = ppo.make_ppo_networks(distribution_value_networks)
            config = ppo.PPOConfig(unroll_length=unroll_length,
                                   num_minibatches=2,
                                   num_epochs=4,
                                   batch_size=batch_size)
            base_builder = ppo.PPOBuilder(config=config)
            direct_rl_batch_size = batch_size * unroll_length
            behavior_policy = jax.jit(ppo.make_inference_fn(networks),
                                      backend='cpu')
        else:
            raise ValueError(f'Unexpected algorithm {algo}')

        if subtract_logpi:
            assert algo == 'sac'
            logpi_fn = make_sac_logpi(networks)
        else:
            logpi_fn = None

        if algo == 'ppo':
            embedding = lambda x: jnp.reshape(x, list(x.shape[:-2]) + [-1])
        else:
            embedding = lambda x: x

        def discriminator(*args, **kwargs) -> networks_lib.Logits:
            if airl_discriminator:
                return ail.AIRLModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    discount=.99,
                    g_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    h_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    observation_embedding=embedding)(*args, **kwargs)
            else:
                return ail.DiscriminatorModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    network_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    observation_embedding=embedding)(*args, **kwargs)

        discriminator_transformed = hk.without_apply_rng(
            hk.transform_with_state(discriminator))

        discriminator_network = ail.make_discriminator(
            environment_spec=spec,
            discriminator_transformed=discriminator_transformed,
            logpi_fn=logpi_fn)

        networks = ail.AILNetworks(discriminator_network, lambda x: x,
                                   networks)

        builder = ail.AILBuilder(
            base_builder,
            config=ail.AILConfig(
                is_sequence_based=(algo == 'ppo'),
                share_iterator=True,
                direct_rl_batch_size=direct_rl_batch_size,
                discriminator_batch_size=2,
                policy_variable_name='policy' if subtract_logpi else None,
                min_replay_size=1),
            discriminator_loss=ail.losses.gail_loss(),
            make_demonstrations=fakes.transition_iterator(environment))

        # Construct the agent.
        agent = local_layout.LocalLayout(seed=0,
                                         environment_spec=spec,
                                         builder=builder,
                                         networks=networks,
                                         policy_network=behavior_policy,
                                         min_replay_size=1,
                                         batch_size=batch_size)

        # Train the agent.
        train_loop = acme.EnvironmentLoop(environment, agent)
        train_loop.run(num_episodes=(10 if algo == 'ppo' else 1))