Example No. 1
    def __init__(
        self,
        environment_factory: jax_types.EnvironmentFactory,
        network_factory: NetworkFactory,
        sac_fd_config: SACfDConfig,
        lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]],
        seed: int,
        num_actors: int,
        environment_spec: Optional[specs.EnvironmentSpec] = None,
        max_number_of_steps: Optional[int] = None,
        log_to_bigtable: bool = False,
        log_every: float = 10.0,
        evaluator_factories: Optional[Sequence[
            distributed_layout.EvaluatorFactory]] = None,
    ):
        logger_fn = functools.partial(loggers.make_default_logger,
                                      'learner',
                                      log_to_bigtable,
                                      time_delta=log_every,
                                      asynchronous=True,
                                      serialize_fn=utils.fetch_devicearray,
                                      steps_key='learner_steps')

        sac_config = sac_fd_config.sac_config
        lfd_config = sac_fd_config.lfd_config
        sac_builder = sac.SACBuilder(sac_config, logger_fn=logger_fn)
        lfd_builder = builder.LfdBuilder(sac_builder, lfd_iterator_fn,
                                         lfd_config)

        if evaluator_factories is None:
            eval_policy_factory = (
                lambda n: sac.apply_policy_and_sample(n, True))
            evaluator_factories = [
                distributed_layout.default_evaluator_factory(
                    environment_factory=environment_factory,
                    network_factory=network_factory,
                    policy_factory=eval_policy_factory,
                    log_to_bigtable=log_to_bigtable)
            ]

        super().__init__(
            seed=seed,
            environment_factory=environment_factory,
            network_factory=network_factory,
            environment_spec=environment_spec,
            builder=lfd_builder,
            policy_network=sac.apply_policy_and_sample,
            evaluator_factories=evaluator_factories,
            num_actors=num_actors,
            max_number_of_steps=max_number_of_steps,
            prefetch_size=sac_config.prefetch_size,
            log_to_bigtable=log_to_bigtable,
            actor_logger_fn=distributed_layout.get_default_logger_fn(
                log_to_bigtable, log_every))
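
If the default evaluator is not suitable, an explicit list of evaluator factories can be passed instead; a minimal sketch, reusing only helpers that already appear in these examples (the task name is a placeholder):

# Hedged sketch: building evaluator factories by hand rather than relying on
# the default branch above. Imports follow the surrounding example.
custom_evaluator_factories = [
    distributed_layout.default_evaluator_factory(
        environment_factory=lambda seed: helpers.make_environment('placeholder-task'),
        network_factory=sac.make_networks,
        policy_factory=lambda n: sac.apply_policy_and_sample(n, True),
        log_to_bigtable=False)
]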
Example No. 2
def main(_):
    task = FLAGS.task
    environment_factory = lambda seed: helpers.make_environment(task)
    sac_config = sac.SACConfig(num_sgd_steps_per_step=64)
    sac_builder = sac.SACBuilder(sac_config)

    ail_config = ail.AILConfig(direct_rl_batch_size=sac_config.batch_size *
                               sac_config.num_sgd_steps_per_step)

    def network_factory(spec: specs.EnvironmentSpec) -> ail.AILNetworks:
        def discriminator(*args, **kwargs) -> networks_lib.Logits:
            return ail.DiscriminatorModule(environment_spec=spec,
                                           use_action=True,
                                           use_next_obs=True,
                                           network_core=ail.DiscriminatorMLP(
                                               [4, 4], ))(*args, **kwargs)

        discriminator_transformed = hk.without_apply_rng(
            hk.transform_with_state(discriminator))

        return ail.AILNetworks(ail.make_discriminator(
            spec, discriminator_transformed),
                               imitation_reward_fn=ail.rewards.gail_reward(),
                               direct_rl_networks=sac.make_networks(spec))

    def policy_network(
            network: ail.AILNetworks,
            eval_mode: bool = False) -> actor_core_lib.FeedForwardPolicy:
        return sac.apply_policy_and_sample(network.direct_rl_networks,
                                           eval_mode=eval_mode)

    program = ail.DistributedAIL(
        environment_factory=environment_factory,
        rl_agent=sac_builder,
        config=ail_config,
        network_factory=network_factory,
        seed=0,
        batch_size=sac_config.batch_size * sac_config.num_sgd_steps_per_step,
        make_demonstrations=functools.partial(
            helpers.make_demonstration_iterator,
            dataset_name=FLAGS.dataset_name),
        policy_network=policy_network,
        evaluator_policy_network=(lambda n: policy_network(n, eval_mode=True)),
        num_actors=4,
        max_number_of_steps=100,
        discriminator_loss=ail.losses.gail_loss()).build()

    # Launch experiment.
    lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program))
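
This main function reads FLAGS.task and FLAGS.dataset_name; a minimal entry-point sketch with placeholder defaults, showing how the script might declare them and hand control to absl:

# Hedged sketch of the script entry point; flag defaults are placeholders.
from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('task', 'HalfCheetah-v2', 'Environment task to solve.')
flags.DEFINE_string('dataset_name', 'demonstrations', 'Demonstration dataset to imitate.')

if __name__ == '__main__':
    app.run(main)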
Example No. 3
    def __init__(self,
                 spec: specs.EnvironmentSpec,
                 sac_network: sac.SACNetworks,
                 sac_fd_config: SACfDConfig,
                 lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]],
                 seed: int,
                 counter: Optional[counting.Counter] = None):
        """New instance of a SACfD agent."""
        sac_config = sac_fd_config.sac_config
        lfd_config = sac_fd_config.lfd_config
        sac_builder = sac.SACBuilder(sac_config)
        lfd_builder = builder.LfdBuilder(sac_builder, lfd_iterator_fn,
                                         lfd_config)

        min_replay_size = sac_config.min_replay_size
        # The local layout (actually agent.Agent) makes sure that the buffer is
        # populated with min_replay_size initial transitions, so no tolerance
        # rate is needed. To avoid deadlocks we must disable the rate limiting
        # that happens inside the SACBuilder, which the following two lines do.
        sac_config.samples_per_insert_tolerance_rate = float('inf')
        sac_config.min_replay_size = 1

        self.builder = lfd_builder
        super().__init__(
            builder=lfd_builder,
            seed=seed,
            environment_spec=spec,
            networks=sac_network,
            policy_network=sac.apply_policy_and_sample(sac_network),
            batch_size=sac_config.batch_size,
            prefetch_size=sac_config.prefetch_size,
            samples_per_insert=sac_config.samples_per_insert,
            min_replay_size=min_replay_size,
            num_sgd_steps_per_step=sac_config.num_sgd_steps_per_step,
            counter=counter,
        )
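
A minimal usage sketch of the agent defined above, assuming it is exposed under the name SACfD and that a demonstration iterator factory is available (both names are assumptions); the loop pattern mirrors the later examples:

# Hedged usage sketch; `SACfD`, `make_lfd_iterator` and `my_sac_fd_config` are
# assumed to exist and are placeholders for the real names.
environment = helpers.make_environment(task='placeholder-task')
spec = specs.make_environment_spec(environment)
networks = sac.make_networks(spec)
agent = SACfD(spec, networks, sac_fd_config=my_sac_fd_config,
              lfd_iterator_fn=make_lfd_iterator, seed=0)
loop = acme.EnvironmentLoop(environment, agent)
loop.run(num_episodes=10)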
Example No. 4
    def __init__(self, sac_fd_config: SACfDConfig,
                 lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]]):
        sac_builder = sac.SACBuilder(sac_fd_config.sac_config)
        super().__init__(sac_builder, lfd_iterator_fn, sac_fd_config.lfd_config)
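
Every example above only reads the sac_config and lfd_config attributes of SACfDConfig; a plausible sketch of such a wrapper, assuming it is a plain dataclass (the real definition may carry additional fields):

# Hedged sketch of the config bundle assumed by the examples above.
import dataclasses
from typing import Any

@dataclasses.dataclass
class SACfDConfig:
    """Bundles the SAC hyperparameters with the LfD configuration."""
    sac_config: sac.SACConfig
    lfd_config: Any  # The concrete LfD config type is not shown in these examples.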
Example No. 5
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = sac.make_networks(environment_spec)

    # Construct the agent.
    # The local layout makes sure that the buffer is populated with
    # min_replay_size initial transitions, so no tolerance rate is needed. To
    # avoid deadlocks we must disable the rate limiting that happens inside the
    # SACBuilder, which is achieved by the min_replay_size and
    # samples_per_insert_tolerance_rate arguments below.
    sac_config = sac.SACConfig(
        target_entropy=sac.target_entropy_from_env_spec(environment_spec),
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
        min_replay_size=1,
        samples_per_insert_tolerance_rate=float('inf'))
    sac_builder = sac.SACBuilder(sac_config)
    sac_networks = sac.make_networks(environment_spec)
    sac_networks = add_bc_pretraining(sac_networks)

    ail_config = ail.AILConfig(direct_rl_batch_size=sac_config.batch_size *
                               sac_config.num_sgd_steps_per_step)

    def discriminator(*args, **kwargs) -> networks_lib.Logits:
        return ail.DiscriminatorModule(environment_spec=environment_spec,
                                       use_action=True,
                                       use_next_obs=True,
                                       network_core=ail.DiscriminatorMLP(
                                           [4, 4], ))(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator))

    ail_network = ail.AILNetworks(
        ail.make_discriminator(environment_spec, discriminator_transformed),
        imitation_reward_fn=ail.rewards.gail_reward(),
        direct_rl_networks=sac_networks)

    agent = ail.AIL(spec=environment_spec,
                    rl_agent=sac_builder,
                    network=ail_network,
                    config=ail_config,
                    seed=FLAGS.seed,
                    batch_size=sac_config.batch_size *
                    sac_config.num_sgd_steps_per_step,
                    make_demonstrations=functools.partial(
                        helpers.make_demonstration_iterator,
                        dataset_name=FLAGS.dataset_name),
                    policy_network=sac.apply_policy_and_sample(sac_networks),
                    discriminator_loss=ail.losses.gail_loss())

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=sac.apply_policy_and_sample(agent_networks,
                                                   eval_mode=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
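
As with the distributed script above, this entry point depends on several flags; a hedged sketch of how they might be declared (defaults are placeholders) and of the call into absl:

# Hedged sketch of the flag declarations this script relies on.
from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('env_name', 'HalfCheetah-v2', 'Environment name.')
flags.DEFINE_string('dataset_name', 'demonstrations', 'Demonstration dataset to imitate.')
flags.DEFINE_integer('num_sgd_steps_per_step', 1, 'SGD steps per learner step.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('num_steps', 1_000_000, 'Total number of training steps.')
flags.DEFINE_integer('eval_every', 50_000, 'Steps between evaluations; must divide num_steps.')

if __name__ == '__main__':
    app.run(main)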
Example No. 6
    def test_ail_flax(self):
        shutil.rmtree(flags.FLAGS.test_tmpdir)
        batch_size = 8
        # Fake continuous environment standing in for a MuJoCo control task.
        environment = fakes.ContinuousEnvironment(
            episode_length=EPISODE_LENGTH,
            action_dim=CONTINUOUS_ACTION_DIM,
            observation_dim=CONTINUOUS_OBS_DIM,
            bounded=True)
        spec = specs.make_environment_spec(environment)

        networks = sac.make_networks(spec=spec)
        config = sac.SACConfig(batch_size=batch_size,
                               samples_per_insert_tolerance_rate=float('inf'),
                               min_replay_size=1)
        base_builder = sac.SACBuilder(config=config)
        direct_rl_batch_size = batch_size
        behavior_policy = sac.apply_policy_and_sample(networks)

        discriminator_module = DiscriminatorModule(spec, linen.Dense(1))

        def apply_fn(params: networks_lib.Params,
                     policy_params: networks_lib.Params,
                     state: networks_lib.Params, transitions: types.Transition,
                     is_training: bool,
                     rng: networks_lib.PRNGKey) -> networks_lib.Logits:
            del policy_params
            variables = dict(params=params, **state)
            return discriminator_module.apply(variables,
                                              transitions.observation,
                                              transitions.action,
                                              transitions.next_observation,
                                              is_training=is_training,
                                              rng=rng,
                                              mutable=state.keys())

        def init_fn(rng):
            variables = discriminator_module.init(rng,
                                                  dummy_obs,
                                                  dummy_actions,
                                                  dummy_obs,
                                                  is_training=False,
                                                  rng=rng)
            init_state, discriminator_params = variables.pop('params')
            return discriminator_params, init_state

        dummy_obs = utils.zeros_like(spec.observations)
        dummy_obs = utils.add_batch_dim(dummy_obs)
        dummy_actions = utils.zeros_like(spec.actions)
        dummy_actions = utils.add_batch_dim(dummy_actions)
        discriminator_network = networks_lib.FeedForwardNetwork(init=init_fn,
                                                                apply=apply_fn)

        networks = ail.AILNetworks(discriminator_network, lambda x: x,
                                   networks)

        builder = ail.AILBuilder(
            base_builder,
            config=ail.AILConfig(is_sequence_based=False,
                                 share_iterator=True,
                                 direct_rl_batch_size=direct_rl_batch_size,
                                 discriminator_batch_size=2,
                                 policy_variable_name=None,
                                 min_replay_size=1),
            discriminator_loss=ail.losses.gail_loss(),
            make_demonstrations=fakes.transition_iterator(environment))

        counter = counting.Counter()
        # Construct the agent.
        agent = local_layout.LocalLayout(
            seed=0,
            environment_spec=spec,
            builder=builder,
            networks=networks,
            policy_network=behavior_policy,
            min_replay_size=1,
            batch_size=batch_size,
            counter=counter,
        )

        # Train the agent.
        train_loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        train_loop.run(num_episodes=1)
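
The test above relies on a test-local DiscriminatorModule built with Flax; a minimal sketch of what such a module could look like, assuming it simply concatenates its inputs and applies the provided core (the real definition may differ, e.g. by carrying batch statistics in its state):

# Hedged sketch of the Flax discriminator module assumed by the test above.
import flax.linen as linen
import jax.numpy as jnp

class DiscriminatorModule(linen.Module):
    environment_spec: specs.EnvironmentSpec
    core: linen.Module

    @linen.compact
    def __call__(self, observation, action, next_observation,
                 is_training: bool = False, rng=None):
        del is_training, rng  # Unused in this sketch.
        inputs = jnp.concatenate([observation, action, next_observation], axis=-1)
        return jnp.squeeze(self.core(inputs), axis=-1)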
Example No. 7
    def test_ail(self,
                 algo,
                 airl_discriminator=False,
                 subtract_logpi=False,
                 dropout=0.,
                 lipschitz_coeff=None):
        shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
        batch_size = 8
        # Fake environment (discrete or continuous) standing in for a MuJoCo-style task.
        if algo == 'ppo':
            environment = fakes.DiscreteEnvironment(
                num_actions=NUM_DISCRETE_ACTIONS,
                num_observations=NUM_OBSERVATIONS,
                obs_shape=OBS_SHAPE,
                obs_dtype=OBS_DTYPE,
                episode_length=EPISODE_LENGTH)
        else:
            environment = fakes.ContinuousEnvironment(
                episode_length=EPISODE_LENGTH,
                action_dim=CONTINUOUS_ACTION_DIM,
                observation_dim=CONTINUOUS_OBS_DIM,
                bounded=True)
        spec = specs.make_environment_spec(environment)

        if algo == 'sac':
            networks = sac.make_networks(spec=spec)
            config = sac.SACConfig(
                batch_size=batch_size,
                samples_per_insert_tolerance_rate=float('inf'),
                min_replay_size=1)
            base_builder = sac.SACBuilder(config=config)
            direct_rl_batch_size = batch_size
            behavior_policy = sac.apply_policy_and_sample(networks)
        elif algo == 'ppo':
            unroll_length = 5
            distribution_value_networks = make_ppo_networks(spec)
            networks = ppo.make_ppo_networks(distribution_value_networks)
            config = ppo.PPOConfig(unroll_length=unroll_length,
                                   num_minibatches=2,
                                   num_epochs=4,
                                   batch_size=batch_size)
            base_builder = ppo.PPOBuilder(config=config)
            direct_rl_batch_size = batch_size * unroll_length
            behavior_policy = jax.jit(ppo.make_inference_fn(networks),
                                      backend='cpu')
        else:
            raise ValueError(f'Unexpected algorithm {algo}')

        if subtract_logpi:
            assert algo == 'sac'
            logpi_fn = make_sac_logpi(networks)
        else:
            logpi_fn = None

        if algo == 'ppo':
            embedding = lambda x: jnp.reshape(x, list(x.shape[:-2]) + [-1])
        else:
            embedding = lambda x: x

        def discriminator(*args, **kwargs) -> networks_lib.Logits:
            if airl_discriminator:
                return ail.AIRLModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    discount=.99,
                    g_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    h_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    observation_embedding=embedding)(*args, **kwargs)
            else:
                return ail.DiscriminatorModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    network_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    observation_embedding=embedding)(*args, **kwargs)

        discriminator_transformed = hk.without_apply_rng(
            hk.transform_with_state(discriminator))

        discriminator_network = ail.make_discriminator(
            environment_spec=spec,
            discriminator_transformed=discriminator_transformed,
            logpi_fn=logpi_fn)

        networks = ail.AILNetworks(discriminator_network, lambda x: x,
                                   networks)

        builder = ail.AILBuilder(
            base_builder,
            config=ail.AILConfig(
                is_sequence_based=(algo == 'ppo'),
                share_iterator=True,
                direct_rl_batch_size=direct_rl_batch_size,
                discriminator_batch_size=2,
                policy_variable_name='policy' if subtract_logpi else None,
                min_replay_size=1),
            discriminator_loss=ail.losses.gail_loss(),
            make_demonstrations=fakes.transition_iterator(environment))

        # Construct the agent.
        agent = local_layout.LocalLayout(seed=0,
                                         environment_spec=spec,
                                         builder=builder,
                                         networks=networks,
                                         policy_network=behavior_policy,
                                         min_replay_size=1,
                                         batch_size=batch_size)

        # Train the agent.
        train_loop = acme.EnvironmentLoop(environment, agent)
        train_loop.run(num_episodes=(10 if algo == 'ppo' else 1))
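
A hedged sketch of how such a test is typically wired up with absl.testing's parameterized cases; the case names and argument combinations below are placeholders, not the ones used in the real test file:

# Hedged sketch of the test harness around test_ail above.
from absl.testing import absltest
from absl.testing import parameterized


class AILTest(parameterized.TestCase):

    @parameterized.named_parameters(
        ('sac', 'sac'),
        ('sac_airl', 'sac', True),
        ('sac_logpi', 'sac', False, True),
        ('ppo', 'ppo'),
    )
    def test_ail(self, algo, airl_discriminator=False, subtract_logpi=False,
                 dropout=0., lipschitz_coeff=None):
        ...  # Body as in the example above.


if __name__ == '__main__':
    absltest.main()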