def __init__(
    self,
    environment_factory: jax_types.EnvironmentFactory,
    network_factory: NetworkFactory,
    sac_fd_config: SACfDConfig,
    lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]],
    seed: int,
    num_actors: int,
    environment_spec: Optional[specs.EnvironmentSpec] = None,
    max_number_of_steps: Optional[int] = None,
    log_to_bigtable: bool = False,
    log_every: float = 10.0,
    evaluator_factories: Optional[Sequence[
        distributed_layout.EvaluatorFactory]] = None,
):
    """Set up the parent layout with a SAC builder wrapped by an LfD builder.

    Args:
        environment_factory: Forwarded to the parent constructor and to the
            default evaluator factory.
        network_factory: Forwarded to the parent constructor and to the
            default evaluator factory.
        sac_fd_config: Provides `sac_config` (for the SAC builder) and
            `lfd_config` (for the LfD builder).
        lfd_iterator_fn: Handed to `builder.LfdBuilder`.
        seed: Forwarded to the parent constructor.
        num_actors: Forwarded to the parent constructor.
        environment_spec: Forwarded to the parent constructor.
        max_number_of_steps: Forwarded to the parent constructor.
        log_to_bigtable: Passed to the learner logger, the actor logger, the
            default evaluator factory and the parent constructor.
        log_every: Used as `time_delta` for the learner logger and passed to
            the default actor logger function.
        evaluator_factories: Evaluators to use. When None, a single default
            evaluator is created whose policy is
            `sac.apply_policy_and_sample(networks, True)`.
    """
    sac_config = sac_fd_config.sac_config

    # Logger used by the learner created through the SAC builder.
    learner_logger_fn = functools.partial(
        loggers.make_default_logger,
        'learner',
        log_to_bigtable,
        time_delta=log_every,
        asynchronous=True,
        serialize_fn=utils.fetch_devicearray,
        steps_key='learner_steps',
    )

    lfd_builder = builder.LfdBuilder(
        sac.SACBuilder(sac_config, logger_fn=learner_logger_fn),
        lfd_iterator_fn,
        sac_fd_config.lfd_config,
    )

    if evaluator_factories is None:

        def eval_policy_factory(networks):
            return sac.apply_policy_and_sample(networks, True)

        default_evaluator = distributed_layout.default_evaluator_factory(
            environment_factory=environment_factory,
            network_factory=network_factory,
            policy_factory=eval_policy_factory,
            log_to_bigtable=log_to_bigtable,
        )
        evaluator_factories = [default_evaluator]

    actor_logger_fn = distributed_layout.get_default_logger_fn(
        log_to_bigtable, log_every)

    super().__init__(
        seed=seed,
        environment_factory=environment_factory,
        network_factory=network_factory,
        environment_spec=environment_spec,
        builder=lfd_builder,
        policy_network=sac.apply_policy_and_sample,
        evaluator_factories=evaluator_factories,
        num_actors=num_actors,
        max_number_of_steps=max_number_of_steps,
        prefetch_size=sac_config.prefetch_size,
        log_to_bigtable=log_to_bigtable,
        actor_logger_fn=actor_logger_fn,
    )
def main(_):
    """Assemble a distributed AIL program on top of SAC and launch it."""
    task = FLAGS.task

    def environment_factory(seed):
        # The seed is accepted for interface compatibility but not used.
        del seed
        return helpers.make_environment(task)

    sac_config = sac.SACConfig(num_sgd_steps_per_step=64)
    sac_builder = sac.SACBuilder(sac_config)

    # One direct-RL batch feeds all SGD steps of a single learner step.
    direct_rl_batch_size = (
        sac_config.batch_size * sac_config.num_sgd_steps_per_step)
    ail_config = ail.AILConfig(direct_rl_batch_size=direct_rl_batch_size)

    def network_factory(spec: specs.EnvironmentSpec) -> ail.AILNetworks:
        """Build the discriminator plus the SAC networks for `spec`."""

        def discriminator(*args, **kwargs) -> networks_lib.Logits:
            core = ail.DiscriminatorMLP([4, 4])
            module = ail.DiscriminatorModule(
                environment_spec=spec,
                use_action=True,
                use_next_obs=True,
                network_core=core,
            )
            return module(*args, **kwargs)

        transformed = hk.without_apply_rng(
            hk.transform_with_state(discriminator))
        return ail.AILNetworks(
            ail.make_discriminator(spec, transformed),
            imitation_reward_fn=ail.rewards.gail_reward(),
            direct_rl_networks=sac.make_networks(spec),
        )

    def policy_network(
        network: ail.AILNetworks,
        eval_mode: bool = False) -> actor_core_lib.FeedForwardPolicy:
        """Sampling policy derived from the direct-RL (SAC) networks."""
        return sac.apply_policy_and_sample(
            network.direct_rl_networks, eval_mode=eval_mode)

    def evaluator_policy_network(n):
        return policy_network(n, eval_mode=True)

    make_demonstrations = functools.partial(
        helpers.make_demonstration_iterator,
        dataset_name=FLAGS.dataset_name)

    distributed_agent = ail.DistributedAIL(
        environment_factory=environment_factory,
        rl_agent=sac_builder,
        config=ail_config,
        network_factory=network_factory,
        seed=0,
        batch_size=direct_rl_batch_size,
        make_demonstrations=make_demonstrations,
        policy_network=policy_network,
        evaluator_policy_network=evaluator_policy_network,
        num_actors=4,
        max_number_of_steps=100,
        discriminator_loss=ail.losses.gail_loss(),
    )
    program = distributed_agent.build()

    # Launch experiment.
    resources = lp_utils.make_xm_docker_resources(program)
    lp.launch(program, xm_resources=resources)
def __init__(self,
             spec: specs.EnvironmentSpec,
             sac_network: sac.SACNetworks,
             sac_fd_config: SACfDConfig,
             lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]],
             seed: int,
             counter: Optional[counting.Counter] = None):
    """Create a SACfD agent: a SAC builder wrapped by an LfD builder.

    Note that `sac_fd_config.sac_config` is modified in place (see the comment
    in the body), so the caller's config object changes as well.

    Args:
        spec: Environment spec forwarded to the parent constructor.
        sac_network: SAC networks; also used to derive the behaviour policy.
        sac_fd_config: Provides `sac_config` and `lfd_config`.
        lfd_iterator_fn: Handed to `builder.LfdBuilder`.
        seed: Forwarded to the parent constructor.
        counter: Optional counter forwarded to the parent constructor.
    """
    sac_config = sac_fd_config.sac_config
    lfd_builder = builder.LfdBuilder(
        sac.SACBuilder(sac_config),
        lfd_iterator_fn,
        sac_fd_config.lfd_config,
    )

    # Remember the user-requested value before it is overridden below; the
    # parent constructor still receives the original setting.
    requested_min_replay_size = sac_config.min_replay_size

    # The local layout (actually agent.Agent) already fills the buffer with
    # min_replay_size initial transitions, so no tolerance rate is needed.
    # To avoid deadlocks, the rate limiting that happens inside the SACBuilder
    # has to be disabled, which the two assignments below achieve.
    sac_config.samples_per_insert_tolerance_rate = float('inf')
    sac_config.min_replay_size = 1

    self.builder = lfd_builder
    behavior_policy = sac.apply_policy_and_sample(sac_network)
    super().__init__(
        builder=lfd_builder,
        seed=seed,
        environment_spec=spec,
        networks=sac_network,
        policy_network=behavior_policy,
        batch_size=sac_config.batch_size,
        prefetch_size=sac_config.prefetch_size,
        samples_per_insert=sac_config.samples_per_insert,
        min_replay_size=requested_min_replay_size,
        num_sgd_steps_per_step=sac_config.num_sgd_steps_per_step,
        counter=counter,
    )
def __init__(self,
             sac_fd_config: SACfDConfig,
             lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]]):
    """Initialise the parent builder around a freshly created SAC builder.

    Args:
        sac_fd_config: Provides `sac_config` for the SAC builder and
            `lfd_config` for the parent constructor.
        lfd_iterator_fn: Forwarded unchanged to the parent constructor.
    """
    super().__init__(
        sac.SACBuilder(sac_fd_config.sac_config),
        lfd_iterator_fn,
        sac_fd_config.lfd_config,
    )
def main(_):
    """Train an AIL agent built on SAC and evaluate it at regular intervals.

    The SAC networks are passed through `add_bc_pretraining` before being
    combined with a GAIL-style discriminator into the AIL networks. Training
    runs in chunks of `FLAGS.eval_every` steps; five evaluation episodes are
    run before each chunk and once more after the final chunk.

    Raises:
        ValueError: If `FLAGS.eval_every` is not positive or does not evenly
            divide `FLAGS.num_steps`.
    """
    # Validate the schedule before any setup so a bad flag fails fast. An
    # explicit raise is used instead of `assert`, which is removed under -O.
    num_steps = FLAGS.num_steps
    eval_every = FLAGS.eval_every
    if eval_every <= 0 or num_steps % eval_every != 0:
        raise ValueError(
            f'num_steps ({num_steps}) must be a multiple of a positive '
            f'eval_every ({eval_every}).')

    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = sac.make_networks(environment_spec)

    # Construct the agent.
    # Local layout makes sure that we populate the buffer with min_replay_size
    # initial transitions and that there's no need for tolerance_rate. In order
    # for deadlocks not to happen we need to disable rate limiting that happens
    # inside the SACBuilder. This is achieved by the min_replay_size and
    # samples_per_insert_tolerance_rate arguments.
    sac_config = sac.SACConfig(
        target_entropy=sac.target_entropy_from_env_spec(environment_spec),
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
        min_replay_size=1,
        samples_per_insert_tolerance_rate=float('inf'))
    sac_builder = sac.SACBuilder(sac_config)
    sac_networks = sac.make_networks(environment_spec)
    sac_networks = add_bc_pretraining(sac_networks)

    # One direct-RL batch covers all SGD steps of a single learner step.
    direct_rl_batch_size = (
        sac_config.batch_size * sac_config.num_sgd_steps_per_step)
    ail_config = ail.AILConfig(direct_rl_batch_size=direct_rl_batch_size)

    def discriminator(*args, **kwargs) -> networks_lib.Logits:
        return ail.DiscriminatorModule(
            environment_spec=environment_spec,
            use_action=True,
            use_next_obs=True,
            network_core=ail.DiscriminatorMLP([4, 4]),
        )(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator))

    ail_network = ail.AILNetworks(
        ail.make_discriminator(environment_spec, discriminator_transformed),
        imitation_reward_fn=ail.rewards.gail_reward(),
        direct_rl_networks=sac_networks)

    agent = ail.AIL(
        spec=environment_spec,
        rl_agent=sac_builder,
        network=ail_network,
        config=ail_config,
        seed=FLAGS.seed,
        batch_size=direct_rl_batch_size,
        make_demonstrations=functools.partial(
            helpers.make_demonstration_iterator,
            dataset_name=FLAGS.dataset_name),
        policy_network=sac.apply_policy_and_sample(sac_networks),
        discriminator_loss=ail.losses.gail_loss())

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(
        environment,
        agent,
        counter=counting.Counter(prefix='train'),
        logger=train_logger)

    # Create the evaluation actor and loop. The evaluation actor reads its
    # variables from the training agent (`variable_source=agent`).
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=sac.apply_policy_and_sample(
            agent_networks, eval_mode=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(
        eval_env,
        eval_actor,
        counter=counting.Counter(prefix='eval'),
        logger=eval_logger)

    # Alternate evaluation and training, then evaluate once more at the end.
    for _ in range(num_steps // eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=eval_every)
    eval_loop.run(num_episodes=5)
def test_ail_flax(self):
    """Smoke-test a local AIL agent whose discriminator is a Flax module.

    Builds a SAC-based AIL agent on a fake continuous environment, wraps a
    `DiscriminatorModule` in init/apply functions, and runs one training
    episode. The test passes if nothing raises.
    """
    # Start from a clean temp dir. `ignore_errors=True` keeps the test from
    # failing when the directory does not exist yet (same as in `test_ail`).
    shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
    batch_size = 8

    # Fake continuous-control environment; demonstrations are generated from
    # it further below.
    environment = fakes.ContinuousEnvironment(
        episode_length=EPISODE_LENGTH,
        action_dim=CONTINUOUS_ACTION_DIM,
        observation_dim=CONTINUOUS_OBS_DIM,
        bounded=True)
    spec = specs.make_environment_spec(environment)

    # Direct RL agent: SAC with rate limiting effectively disabled.
    networks = sac.make_networks(spec=spec)
    config = sac.SACConfig(
        batch_size=batch_size,
        samples_per_insert_tolerance_rate=float('inf'),
        min_replay_size=1)
    base_builder = sac.SACBuilder(config=config)
    direct_rl_batch_size = batch_size
    behavior_policy = sac.apply_policy_and_sample(networks)

    discriminator_module = DiscriminatorModule(spec, linen.Dense(1))

    # Batched zero inputs used only to initialise the discriminator.
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)
    dummy_actions = utils.zeros_like(spec.actions)
    dummy_actions = utils.add_batch_dim(dummy_actions)

    def apply_fn(params: networks_lib.Params,
                 policy_params: networks_lib.Params,
                 state: networks_lib.Params,
                 transitions: types.Transition,
                 is_training: bool,
                 rng: networks_lib.PRNGKey) -> networks_lib.Logits:
        # The policy parameters are part of the interface but unused here.
        del policy_params
        variables = dict(params=params, **state)
        return discriminator_module.apply(
            variables,
            transitions.observation,
            transitions.action,
            transitions.next_observation,
            is_training=is_training,
            rng=rng,
            mutable=state.keys())

    def init_fn(rng):
        variables = discriminator_module.init(
            rng, dummy_obs, dummy_actions, dummy_obs,
            is_training=False, rng=rng)
        # NOTE(review): relies on `pop` returning a
        # (remaining collections, popped value) pair — confirm against the
        # Flax version in use.
        init_state, discriminator_params = variables.pop('params')
        return discriminator_params, init_state

    discriminator_network = networks_lib.FeedForwardNetwork(
        init=init_fn, apply=apply_fn)

    # The imitation reward function is the identity on its input.
    networks = ail.AILNetworks(discriminator_network, lambda x: x, networks)

    builder = ail.AILBuilder(
        base_builder,
        config=ail.AILConfig(
            is_sequence_based=False,
            share_iterator=True,
            direct_rl_batch_size=direct_rl_batch_size,
            discriminator_batch_size=2,
            policy_variable_name=None,
            min_replay_size=1),
        discriminator_loss=ail.losses.gail_loss(),
        make_demonstrations=fakes.transition_iterator(environment))

    counter = counting.Counter()
    # Construct the agent.
    agent = local_layout.LocalLayout(
        seed=0,
        environment_spec=spec,
        builder=builder,
        networks=networks,
        policy_network=behavior_policy,
        min_replay_size=1,
        batch_size=batch_size,
        counter=counter,
    )

    # Train the agent.
    train_loop = acme.EnvironmentLoop(environment, agent, counter=counter)
    train_loop.run(num_episodes=1)
def test_ail(self,
             algo,
             airl_discriminator=False,
             subtract_logpi=False,
             dropout=0.,
             lipschitz_coeff=None):
    """Smoke-test a local AIL agent with a Haiku discriminator.

    Args:
        algo: Direct RL algorithm, either 'sac' or 'ppo'; anything else
            raises ValueError.
        airl_discriminator: Use `ail.AIRLModule` instead of
            `ail.DiscriminatorModule`.
        subtract_logpi: Pass a SAC log-probability function to the
            discriminator (only allowed with 'sac').
        dropout: Hidden dropout rate of the discriminator MLPs.
        lipschitz_coeff: Spectral-normalization Lipschitz coefficient of the
            discriminator MLPs.
    """
    shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
    batch_size = 8
    is_ppo = algo == 'ppo'

    # Fake environment; demonstrations are generated from it further below.
    if is_ppo:
        environment = fakes.DiscreteEnvironment(
            num_actions=NUM_DISCRETE_ACTIONS,
            num_observations=NUM_OBSERVATIONS,
            obs_shape=OBS_SHAPE,
            obs_dtype=OBS_DTYPE,
            episode_length=EPISODE_LENGTH)
    else:
        environment = fakes.ContinuousEnvironment(
            episode_length=EPISODE_LENGTH,
            action_dim=CONTINUOUS_ACTION_DIM,
            observation_dim=CONTINUOUS_OBS_DIM,
            bounded=True)
    spec = specs.make_environment_spec(environment)

    # Direct RL agent.
    if algo == 'sac':
        rl_networks = sac.make_networks(spec=spec)
        rl_config = sac.SACConfig(
            batch_size=batch_size,
            samples_per_insert_tolerance_rate=float('inf'),
            min_replay_size=1)
        base_builder = sac.SACBuilder(config=rl_config)
        direct_rl_batch_size = batch_size
        behavior_policy = sac.apply_policy_and_sample(rl_networks)
    elif is_ppo:
        unroll_length = 5
        distribution_value_networks = make_ppo_networks(spec)
        rl_networks = ppo.make_ppo_networks(distribution_value_networks)
        rl_config = ppo.PPOConfig(
            unroll_length=unroll_length,
            num_minibatches=2,
            num_epochs=4,
            batch_size=batch_size)
        base_builder = ppo.PPOBuilder(config=rl_config)
        direct_rl_batch_size = batch_size * unroll_length
        behavior_policy = jax.jit(
            ppo.make_inference_fn(rl_networks), backend='cpu')
    else:
        raise ValueError(f'Unexpected algorithm {algo}')

    if subtract_logpi:
        assert algo == 'sac'
    logpi_fn = make_sac_logpi(rl_networks) if subtract_logpi else None

    if is_ppo:

        def embedding(x):
            # Merge the last two axes into a single one.
            return jnp.reshape(x, [*x.shape[:-2], -1])
    else:

        def embedding(x):
            return x

    def make_core():
        # Every discriminator core shares the same MLP configuration.
        return ail.DiscriminatorMLP(
            [4, 4],
            hidden_dropout_rate=dropout,
            spectral_normalization_lipschitz_coeff=lipschitz_coeff)

    def discriminator(*args, **kwargs) -> networks_lib.Logits:
        shared = dict(
            environment_spec=spec,
            use_action=True,
            use_next_obs=True,
            observation_embedding=embedding)
        if airl_discriminator:
            g_core = make_core()
            h_core = make_core()
            module = ail.AIRLModule(
                discount=.99, g_core=g_core, h_core=h_core, **shared)
            return module(*args, **kwargs)
        module = ail.DiscriminatorModule(network_core=make_core(), **shared)
        return module(*args, **kwargs)

    discriminator_transformed = hk.without_apply_rng(
        hk.transform_with_state(discriminator))

    discriminator_network = ail.make_discriminator(
        environment_spec=spec,
        discriminator_transformed=discriminator_transformed,
        logpi_fn=logpi_fn)

    # The imitation reward function is the identity on its input.
    ail_networks = ail.AILNetworks(
        discriminator_network, lambda x: x, rl_networks)

    ail_config = ail.AILConfig(
        is_sequence_based=is_ppo,
        share_iterator=True,
        direct_rl_batch_size=direct_rl_batch_size,
        discriminator_batch_size=2,
        policy_variable_name='policy' if subtract_logpi else None,
        min_replay_size=1)
    builder = ail.AILBuilder(
        base_builder,
        config=ail_config,
        discriminator_loss=ail.losses.gail_loss(),
        make_demonstrations=fakes.transition_iterator(environment))

    # Construct the agent.
    agent = local_layout.LocalLayout(
        seed=0,
        environment_spec=spec,
        builder=builder,
        networks=ail_networks,
        policy_network=behavior_policy,
        min_replay_size=1,
        batch_size=batch_size)

    # Train the agent.
    num_episodes = 10 if is_ppo else 1
    train_loop = acme.EnvironmentLoop(environment, agent)
    train_loop.run(num_episodes=num_episodes)