Exemple #1
0
    def network_factory(spec: specs.EnvironmentSpec) -> ail.AILNetworks:
        """Builds AIL networks for `spec`: a GAIL-rewarded discriminator over SAC.

        The discriminator sees the action and the next observation in addition
        to the observation, and uses a small [4, 4] MLP as its core.

        Args:
          spec: Environment spec used for both the discriminator and SAC.

        Returns:
          ail.AILNetworks combining the discriminator, `gail_reward()` as the
          imitation reward, and `sac.make_networks(spec)` as direct RL networks.
        """

        def discriminator(*args, **kwargs) -> networks_lib.Logits:
            # The core is created before the wrapping module, as in the
            # original one-expression form, so Haiku module naming is unchanged.
            core = ail.DiscriminatorMLP([4, 4])
            module = ail.DiscriminatorModule(environment_spec=spec,
                                             use_action=True,
                                             use_next_obs=True,
                                             network_core=core)
            return module(*args, **kwargs)

        transformed = hk.without_apply_rng(
            hk.transform_with_state(discriminator))
        discriminator_network = ail.make_discriminator(spec, transformed)

        return ail.AILNetworks(discriminator_network,
                               imitation_reward_fn=ail.rewards.gail_reward(),
                               direct_rl_networks=sac.make_networks(spec))
Exemple #2
0
def main(_):
    """Trains a DAC agent (AIL over TD3) and evaluates it periodically.

    Runs FLAGS.num_steps // FLAGS.eval_every rounds. Each round is a 5-episode
    evaluation followed by FLAGS.eval_every training steps. A final 5-episode
    evaluation runs at the end.

    Args:
      _: Unused.

    Raises:
      ValueError: If FLAGS.eval_every is not positive or does not evenly
        divide FLAGS.num_steps.
    """
    # Validate the train/eval schedule before building anything expensive.
    # An explicit check survives `python -O` (unlike `assert`). It also
    # rejects eval_every <= 0, which would otherwise raise ZeroDivisionError
    # or silently skip training.
    if FLAGS.eval_every <= 0 or FLAGS.num_steps % FLAGS.eval_every != 0:
        raise ValueError(
            f'--num_steps ({FLAGS.num_steps}) must be a multiple of '
            f'--eval_every ({FLAGS.eval_every}), which must be positive.')
    num_eval_episodes = 5

    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Construct the agent.
    # Local layout makes sure that we populate the buffer with min_replay_size
    # initial transitions and that there's no need for tolerance_rate. In order
    # for deadlocks not to happen we need to disable rate limiting that happens
    # inside the TD3Builder. This is achieved by the min_replay_size and
    # samples_per_insert_tolerance_rate arguments.
    td3_config = td3.TD3Config(
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
        min_replay_size=1,
        samples_per_insert_tolerance_rate=float('inf'))
    td3_networks = td3.make_networks(environment_spec)
    if FLAGS.pretrain:
        td3_networks = add_bc_pretraining(td3_networks)

    # TD3 batch size times SGD steps per learner step. It is used both as the
    # AIL direct-RL batch size and as the DAC agent's batch size, so it is
    # computed once.
    direct_rl_batch_size = (td3_config.batch_size *
                            td3_config.num_sgd_steps_per_step)
    ail_config = ail.AILConfig(direct_rl_batch_size=direct_rl_batch_size)
    dac_config = ail.DACConfig(ail_config, td3_config)

    def discriminator(*args, **kwargs) -> networks_lib.Logits:
        # Discriminator over (observation, action, next observation) with a
        # small [4, 4] MLP core.
        return ail.DiscriminatorModule(environment_spec=environment_spec,
                                       use_action=True,
                                       use_next_obs=True,
                                       network_core=ail.DiscriminatorMLP(
                                           [4, 4], ))(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator))

    ail_network = ail.AILNetworks(
        ail.make_discriminator(environment_spec, discriminator_transformed),
        imitation_reward_fn=ail.rewards.gail_reward(),
        direct_rl_networks=td3_networks)

    agent = ail.DAC(spec=environment_spec,
                    network=ail_network,
                    config=dac_config,
                    seed=FLAGS.seed,
                    batch_size=direct_rl_batch_size,
                    make_demonstrations=functools.partial(
                        helpers.make_demonstration_iterator,
                        dataset_name=FLAGS.dataset_name),
                    policy_network=td3.get_default_behavior_policy(
                        td3_networks,
                        action_specs=environment_spec.actions,
                        sigma=td3_config.sigma))

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    # The evaluation policy below is built with sigma=0 (no exploration noise),
    # while training uses td3_config.sigma.
    # TODO(lukstafi): sigma=0 for eval?
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=td3.get_default_behavior_policy(
            td3_networks, action_specs=environment_spec.actions, sigma=0.),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=num_eval_episodes)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=num_eval_episodes)
Exemple #3
0
    def test_ail_flax(self):
        """Trains AIL with a Flax (linen) discriminator over SAC for one episode.

        The discriminator network is assembled by hand from a linen module.
        `init_fn` splits the initialized variables into `params` and the
        remaining collections (the state). `apply_fn` merges them back and
        marks the state collections as mutable.
        """
        # ignore_errors: do not fail if the temp dir does not exist; this
        # matches test_ail.
        shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
        batch_size = 8
        # Fake continuous environment; it also serves as the demonstration
        # source via fakes.transition_iterator below.
        environment = fakes.ContinuousEnvironment(
            episode_length=EPISODE_LENGTH,
            action_dim=CONTINUOUS_ACTION_DIM,
            observation_dim=CONTINUOUS_OBS_DIM,
            bounded=True)
        spec = specs.make_environment_spec(environment)

        sac_networks = sac.make_networks(spec=spec)
        config = sac.SACConfig(batch_size=batch_size,
                               samples_per_insert_tolerance_rate=float('inf'),
                               min_replay_size=1)
        base_builder = sac.SACBuilder(config=config)
        direct_rl_batch_size = batch_size
        behavior_policy = sac.apply_policy_and_sample(sac_networks)

        discriminator_module = DiscriminatorModule(spec, linen.Dense(1))

        # Batched all-zero inputs used only to initialize the discriminator.
        # They are defined before init_fn, which closes over them.
        dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
        dummy_actions = utils.add_batch_dim(utils.zeros_like(spec.actions))

        def apply_fn(params: networks_lib.Params,
                     policy_params: networks_lib.Params,
                     state: networks_lib.Params, transitions: types.Transition,
                     is_training: bool,
                     rng: networks_lib.PRNGKey) -> networks_lib.Logits:
            """Applies the discriminator to a batch of transitions."""
            del policy_params  # This discriminator does not use the policy.
            variables = dict(params=params, **state)
            # Every non-params collection is allowed to be updated.
            return discriminator_module.apply(variables,
                                              transitions.observation,
                                              transitions.action,
                                              transitions.next_observation,
                                              is_training=is_training,
                                              rng=rng,
                                              mutable=state.keys())

        def init_fn(rng):
            """Returns (params, state) for the discriminator."""
            variables = discriminator_module.init(rng,
                                                  dummy_obs,
                                                  dummy_actions,
                                                  dummy_obs,
                                                  is_training=False,
                                                  rng=rng)
            # NOTE(review): relies on `pop` returning (remaining, popped), as
            # Flax's FrozenDict does; confirm against the Flax version in use.
            init_state, discriminator_params = variables.pop('params')
            return discriminator_params, init_state

        discriminator_network = networks_lib.FeedForwardNetwork(init=init_fn,
                                                                apply=apply_fn)

        ail_networks = ail.AILNetworks(discriminator_network, lambda x: x,
                                       sac_networks)

        builder = ail.AILBuilder(
            base_builder,
            config=ail.AILConfig(is_sequence_based=False,
                                 share_iterator=True,
                                 direct_rl_batch_size=direct_rl_batch_size,
                                 discriminator_batch_size=2,
                                 policy_variable_name=None,
                                 min_replay_size=1),
            discriminator_loss=ail.losses.gail_loss(),
            make_demonstrations=fakes.transition_iterator(environment))

        counter = counting.Counter()
        # Construct the agent.
        agent = local_layout.LocalLayout(
            seed=0,
            environment_spec=spec,
            builder=builder,
            networks=ail_networks,
            policy_network=behavior_policy,
            min_replay_size=1,
            batch_size=batch_size,
            counter=counter,
        )

        # Train the agent.
        train_loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        train_loop.run(num_episodes=1)
Exemple #4
0
def main(_):
    """Trains a GAIL agent (AIL over PPO) and evaluates it periodically.

    Runs FLAGS.num_steps // FLAGS.eval_every rounds. Each round is a 5-episode
    evaluation followed by FLAGS.eval_every training steps. A final 5-episode
    evaluation runs at the end.

    Args:
      _: Unused.

    Raises:
      ValueError: If FLAGS.eval_every is not positive or does not evenly
        divide FLAGS.num_steps.
    """
    # Validate the train/eval schedule before building anything expensive.
    # An explicit check survives `python -O` (unlike `assert`). It also
    # rejects eval_every <= 0, which would otherwise raise ZeroDivisionError
    # or silently skip training.
    if FLAGS.eval_every <= 0 or FLAGS.num_steps % FLAGS.eval_every != 0:
        raise ValueError(
            f'--num_steps ({FLAGS.num_steps}) must be a multiple of '
            f'--eval_every ({FLAGS.eval_every}), which must be positive.')
    num_eval_episodes = 5

    # Create an environment and grab the spec.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Construct the agent.
    # batch_size counts unrolls. It is a floor division, so any remainder of
    # transition_batch_size modulo unroll_length is dropped.
    ppo_config = ppo.PPOConfig(unroll_length=FLAGS.unroll_length,
                               num_minibatches=FLAGS.ppo_num_minibatches,
                               num_epochs=FLAGS.ppo_num_epochs,
                               batch_size=FLAGS.transition_batch_size //
                               FLAGS.unroll_length,
                               learning_rate=0.0003,
                               entropy_cost=0,
                               gae_lambda=0.8,
                               value_cost=0.25)
    # A single set of networks is used for training AND evaluation, so the
    # evaluation policy matches the trained (possibly BC-pretrained) one.
    ppo_networks = ppo.make_continuous_networks(environment_spec)
    if FLAGS.pretrain:
        ppo_networks = add_bc_pretraining(ppo_networks)

    discriminator_batch_size = FLAGS.transition_batch_size
    ail_config = ail.AILConfig(
        direct_rl_batch_size=ppo_config.batch_size * ppo_config.unroll_length,
        discriminator_batch_size=discriminator_batch_size,
        is_sequence_based=True,
        num_sgd_steps_per_step=FLAGS.num_discriminator_steps_per_step,
        share_iterator=FLAGS.share_iterator,
    )

    def discriminator(*args, **kwargs) -> networks_lib.Logits:
        # Note: observation embedding is not needed for e.g. Mujoco.
        return ail.DiscriminatorModule(
            environment_spec=environment_spec,
            use_action=True,
            use_next_obs=True,
            network_core=ail.DiscriminatorMLP([4, 4], ),
        )(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator))

    ail_network = ail.AILNetworks(
        ail.make_discriminator(environment_spec, discriminator_transformed),
        imitation_reward_fn=ail.rewards.gail_reward(),
        direct_rl_networks=ppo_networks)

    agent = ail.GAIL(spec=environment_spec,
                     network=ail_network,
                     config=ail.GAILConfig(ail_config, ppo_config),
                     seed=FLAGS.seed,
                     batch_size=ppo_config.batch_size,
                     make_demonstrations=functools.partial(
                         helpers.make_demonstration_iterator,
                         dataset_name=FLAGS.dataset_name),
                     policy_network=ppo.make_inference_fn(ppo_networks))

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=ppo.make_inference_fn(ppo_networks, evaluation=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=num_eval_episodes)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=num_eval_episodes)
Exemple #5
0
    def test_ail(self,
                 algo,
                 airl_discriminator=False,
                 subtract_logpi=False,
                 dropout=0.,
                 lipschitz_coeff=None):
        """Runs a short AIL training loop on a fake environment.

        Args:
          algo: Direct RL algorithm, 'sac' or 'ppo'. Any other value raises
            ValueError.
          airl_discriminator: If true, use ail.AIRLModule (with separate g and h
            cores) instead of ail.DiscriminatorModule.
          subtract_logpi: If true, pass a SAC log-prob function to the
            discriminator (only valid with algo == 'sac').
          dropout: Hidden dropout rate of every discriminator MLP core.
          lipschitz_coeff: Forwarded to every discriminator MLP core as
            spectral_normalization_lipschitz_coeff.
        """
        shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
        batch_size = 8
        is_ppo = algo == 'ppo'

        # Fake environment: discrete for PPO, continuous otherwise. It also
        # provides the demonstrations via fakes.transition_iterator below.
        if is_ppo:
            environment = fakes.DiscreteEnvironment(
                num_actions=NUM_DISCRETE_ACTIONS,
                num_observations=NUM_OBSERVATIONS,
                obs_shape=OBS_SHAPE,
                obs_dtype=OBS_DTYPE,
                episode_length=EPISODE_LENGTH)
        else:
            environment = fakes.ContinuousEnvironment(
                episode_length=EPISODE_LENGTH,
                action_dim=CONTINUOUS_ACTION_DIM,
                observation_dim=CONTINUOUS_OBS_DIM,
                bounded=True)
        spec = specs.make_environment_spec(environment)

        # Direct RL networks, builder, batch size and behavior policy.
        if algo == 'sac':
            rl_networks = sac.make_networks(spec=spec)
            sac_config = sac.SACConfig(
                batch_size=batch_size,
                samples_per_insert_tolerance_rate=float('inf'),
                min_replay_size=1)
            base_builder = sac.SACBuilder(config=sac_config)
            direct_rl_batch_size = batch_size
            behavior_policy = sac.apply_policy_and_sample(rl_networks)
        elif is_ppo:
            unroll_length = 5
            rl_networks = ppo.make_ppo_networks(make_ppo_networks(spec))
            ppo_config = ppo.PPOConfig(unroll_length=unroll_length,
                                       num_minibatches=2,
                                       num_epochs=4,
                                       batch_size=batch_size)
            base_builder = ppo.PPOBuilder(config=ppo_config)
            direct_rl_batch_size = batch_size * unroll_length
            behavior_policy = jax.jit(ppo.make_inference_fn(rl_networks),
                                      backend='cpu')
        else:
            raise ValueError(f'Unexpected algorithm {algo}')

        logpi_fn = None
        if subtract_logpi:
            assert algo == 'sac'
            logpi_fn = make_sac_logpi(rl_networks)

        def embedding(x):
            # PPO observations get their two trailing axes flattened into one;
            # other observations pass through unchanged.
            if is_ppo:
                return jnp.reshape(x, list(x.shape[:-2]) + [-1])
            return x

        def make_core():
            # Fresh [4, 4] MLP core carrying the test's regularization settings.
            return ail.DiscriminatorMLP(
                [4, 4],
                hidden_dropout_rate=dropout,
                spectral_normalization_lipschitz_coeff=lipschitz_coeff)

        def discriminator(*args, **kwargs) -> networks_lib.Logits:
            # Cores are created before their wrapping module (and g before h),
            # so Haiku module naming matches the inline-construction form.
            if not airl_discriminator:
                module = ail.DiscriminatorModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    network_core=make_core(),
                    observation_embedding=embedding)
            else:
                g_core = make_core()
                h_core = make_core()
                module = ail.AIRLModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    discount=.99,
                    g_core=g_core,
                    h_core=h_core,
                    observation_embedding=embedding)
            return module(*args, **kwargs)

        discriminator_network = ail.make_discriminator(
            environment_spec=spec,
            discriminator_transformed=hk.without_apply_rng(
                hk.transform_with_state(discriminator)),
            logpi_fn=logpi_fn)

        ail_networks = ail.AILNetworks(discriminator_network, lambda x: x,
                                       rl_networks)

        ail_config = ail.AILConfig(
            is_sequence_based=is_ppo,
            share_iterator=True,
            direct_rl_batch_size=direct_rl_batch_size,
            discriminator_batch_size=2,
            policy_variable_name='policy' if subtract_logpi else None,
            min_replay_size=1)
        builder = ail.AILBuilder(
            base_builder,
            config=ail_config,
            discriminator_loss=ail.losses.gail_loss(),
            make_demonstrations=fakes.transition_iterator(environment))

        # Construct the agent.
        agent = local_layout.LocalLayout(seed=0,
                                         environment_spec=spec,
                                         builder=builder,
                                         networks=ail_networks,
                                         policy_network=behavior_policy,
                                         min_replay_size=1,
                                         batch_size=batch_size)

        # Train the agent; PPO needs more episodes to fill its unrolls.
        num_train_episodes = 10 if is_ppo else 1
        acme.EnvironmentLoop(environment, agent).run(
            num_episodes=num_train_episodes)