def test_train(self, policy_loss_coeff_fn):
  seed = 0
  num_iterations = 5
  batch_size = 64
  grad_updates_per_batch = 1

  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(
      episode_length=10, bounded=True, action_dim=6)
  spec = specs.make_environment_spec(environment)

  # Construct the learner.
  networks = crr.make_networks(
      spec, policy_layer_sizes=(8, 8), critic_layer_sizes=(8, 8))
  key = jax.random.PRNGKey(seed)
  dataset = fakes.transition_iterator(environment)
  learner = crr.CRRLearner(
      networks,
      key,
      discount=0.95,
      target_update_period=2,
      policy_loss_coeff_fn=policy_loss_coeff_fn,
      iterator=dataset(batch_size * grad_updates_per_batch),
      policy_optimizer=optax.adam(1e-4),
      critic_optimizer=optax.adam(1e-4),
      grad_updates_per_batch=grad_updates_per_batch)

  # Train the learner.
  for _ in range(num_iterations):
    learner.step()
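# `fakes.transition_iterator` (used by several tests in this section) returns
# a callable mapping a batch size to a numpy iterator over `types.Transition`
# batches, which is why the tests call `dataset(batch_size)` before handing
# the result to a learner. A minimal standalone usage sketch with
# illustrative values:
def _transition_iterator_usage_sketch():
  environment = fakes.ContinuousEnvironment(
      episode_length=10, bounded=True, action_dim=6)
  dataset_factory = fakes.transition_iterator(environment)
  iterator = dataset_factory(64)
  batch = next(iterator)  # A `types.Transition` with leading dimension 64.
  return batch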
def test_d4pg(self):
  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(
      episode_length=10, action_dim=3, observation_dim=5, bounded=True)
  spec = specs.make_environment_spec(environment)

  # Create the networks.
  networks = make_networks(spec)

  config = d4pg.D4PGConfig(
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10,
      samples_per_insert_tolerance_rate=float('inf'))
  counter = counting.Counter()
  agent = d4pg.D4PG(
      spec, networks, config=config, random_seed=0, counter=counter)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=2)
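# The `make_networks` helper used by `test_d4pg` above is defined elsewhere
# in its test module. A minimal sketch, assuming it defers to the networks
# factory that ships with the JAX D4PG agent, with small layers to keep the
# test fast (the layer sizes are illustrative assumptions):
def make_networks(spec):
  return d4pg.make_networks(
      spec, policy_layer_sizes=(8, 8), critic_layer_sizes=(8, 8))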
def test_mompo(self, distributional_critic):
  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(episode_length=10)
  spec = specs.make_environment_spec(environment)

  # Create objectives.
  reward_objectives, qvalue_objectives = make_objectives()
  num_critic_heads = len(reward_objectives)

  # Create networks.
  agent_networks = make_networks(
      spec.actions,
      num_critic_heads=num_critic_heads,
      distributional_critic=distributional_critic)

  # Construct the agent.
  agent = mompo.MultiObjectiveMPO(
      reward_objectives,
      qvalue_objectives,
      spec,
      policy_network=agent_networks['policy'],
      critic_network=agent_networks['critic'],
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=2)
def test_make_dataset_with_sequence_length_and_batch_size(self):
  sequence_length = 6
  batch_size = 4
  environment = fakes.ContinuousEnvironment()
  environment_spec = specs.make_environment_spec(environment)
  dataset = reverb_dataset.make_dataset(
      client=self.tf_client,
      environment_spec=environment_spec,
      batch_size=batch_size,
      sequence_length=sequence_length)

  def make_tensor_spec(spec):
    return tf.TensorSpec(
        shape=(batch_size, sequence_length) + spec.shape, dtype=spec.dtype)

  expected_spec = tree.map_structure(make_tensor_spec, environment_spec)

  expected_spec = adders.Step(
      observation=expected_spec.observations,
      action=expected_spec.actions,
      reward=expected_spec.rewards,
      discount=expected_spec.discounts,
      start_of_episode=specs.Array(
          shape=(batch_size, sequence_length), dtype=bool),
      extras=())

  self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))
def test_td3_fd(self):
  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(
      episode_length=10, action_dim=3, observation_dim=5, bounded=True)
  spec = specs.make_environment_spec(environment)

  # Create the networks.
  td3_network = td3.make_networks(spec)

  batch_size = 10
  td3_config = td3.TD3Config(batch_size=batch_size, min_replay_size=1)
  lfd_config = lfd.LfdConfig(initial_insert_count=0, demonstration_ratio=0.2)
  td3_fd_config = lfd.TD3fDConfig(
      lfd_config=lfd_config, td3_config=td3_config)
  counter = counting.Counter()
  agent = lfd.TD3fD(
      spec=spec,
      td3_network=td3_network,
      td3_fd_config=td3_fd_config,
      lfd_iterator_fn=fake_demonstration_iterator,
      seed=0,
      counter=counter)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=20)
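# `fake_demonstration_iterator` is referenced above but not defined in this
# section. A minimal sketch under the assumption that the LfD machinery
# consumes (action, next_timestep) pairs replayed as short fake episodes; the
# element type, episode layout, and sizes here are assumptions matching the
# fake environment above, not the test's actual implementation:
def fake_demonstration_iterator():
  rng = np.random.RandomState(0)
  while True:
    observation = rng.uniform(size=(5,)).astype(np.float32)
    # The action paired with a restart step is unused.
    yield np.zeros((3,), dtype=np.float32), dm_env.restart(observation)
    for t in range(9):
      action = rng.uniform(-1., 1., size=(3,)).astype(np.float32)
      observation = rng.uniform(size=(5,)).astype(np.float32)
      if t == 8:
        timestep = dm_env.termination(reward=0., observation=observation)
      else:
        timestep = dm_env.transition(reward=0., observation=observation)
      yield action, timestep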
def test_train(self):
  seed = 0
  num_iterations = 2
  batch_size = 64

  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(
      episode_length=10, bounded=True, action_dim=6)
  spec = specs.make_environment_spec(environment)

  # Construct the agent.
  networks = cql.make_networks(spec, hidden_layer_sizes=(8, 8))
  dataset = fakes.transition_iterator(environment)
  key = jax.random.PRNGKey(seed)
  learner = cql.CQLLearner(
      batch_size,
      networks,
      key,
      demonstrations=dataset(batch_size),
      policy_optimizer=optax.adam(3e-5),
      critic_optimizer=optax.adam(3e-4),
      cql_lagrange_threshold=1.,
      target_entropy=0.1,
      num_sgd_steps_per_step=1)

  # Train the agent.
  for _ in range(num_iterations):
    learner.step()
def test_value_dice(self):
  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(
      episode_length=10, action_dim=3, observation_dim=5, bounded=True)
  spec = specs.make_environment_spec(environment)

  # Create the networks.
  network = value_dice.make_networks(spec)

  config = value_dice.ValueDiceConfig(batch_size=10, min_replay_size=1)
  counter = counting.Counter()
  agent = value_dice.ValueDice(
      spec=spec,
      network=network,
      config=config,
      make_demonstrations=fakes.transition_iterator(environment),
      seed=0,
      counter=counter)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=2)
def test_make_dataset_simple(self):
  environment = fakes.ContinuousEnvironment()
  environment_spec = specs.make_environment_spec(environment)
  dataset = reverb_dataset.make_dataset(
      client=self.tf_client, environment_spec=environment_spec)
  self.assertTrue(
      _check_specs(tuple(environment_spec), dataset.element_spec.data))
def test_make_dataset_transition_adder(self):
  environment = fakes.ContinuousEnvironment()
  environment_spec = specs.make_environment_spec(environment)
  dataset = reverb_dataset.make_dataset(
      client=self.tf_client,
      environment_spec=environment_spec,
      transition_adder=True)

  # A transition adder appends the next observation to the environment spec.
  expected_spec = tuple(environment_spec) + (environment_spec.observations,)

  self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))
def test_make_dataset_simple(self):
  environment = fakes.ContinuousEnvironment()
  environment_spec = specs.make_environment_spec(environment)
  dataset = reverb_dataset.make_dataset(
      client=self.tf_client, environment_spec=environment_spec)

  expected_spec = adders.Step(
      observation=environment_spec.observations,
      action=environment_spec.actions,
      reward=environment_spec.rewards,
      discount=environment_spec.discounts,
      start_of_episode=specs.Array(shape=(), dtype=bool),
      extras=())

  self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))
def test_make_dataset_with_batch_size(self):
  batch_size = 4
  environment = fakes.ContinuousEnvironment()
  environment_spec = specs.make_environment_spec(environment)
  dataset = reverb_dataset.make_dataset(
      client=self.tf_client,
      environment_spec=environment_spec,
      batch_size=batch_size)

  def make_tensor_spec(spec):
    return tf.TensorSpec(shape=(None,) + spec.shape, dtype=spec.dtype)

  expected_spec = tree.map_structure(make_tensor_spec, environment_spec)
  self.assertTrue(
      _check_specs(tuple(expected_spec), dataset.element_spec.data))
def test_make_dataset_transition_adder(self):
  environment = fakes.ContinuousEnvironment()
  environment_spec = specs.make_environment_spec(environment)
  dataset = reverb_dataset.make_dataset(
      server_address=self.server_address,
      environment_spec=environment_spec,
      transition_adder=True)

  expected_spec = types.Transition(
      observation=environment_spec.observations,
      action=environment_spec.actions,
      reward=environment_spec.rewards,
      discount=environment_spec.discounts,
      next_observation=environment_spec.observations,
      extras=())

  self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))
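# `_check_specs` is assumed by the dataset tests above but not shown in this
# section. A minimal sketch of such a helper, assuming it verifies that two
# (possibly nested) spec structures agree leaf by leaf on shape and dtype;
# this is an illustration, not necessarily the test module's actual
# implementation:
def _check_specs(expected, actual) -> bool:
  expected_leaves = tree.flatten(expected)
  actual_leaves = tree.flatten(actual)
  if len(expected_leaves) != len(actual_leaves):
    return False
  return all(
      e.shape == a.shape and e.dtype == a.dtype
      for e, a in zip(expected_leaves, actual_leaves))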
def test_continuous_actions(self, loss_name):
  with chex.fake_pmap_and_jit():
    num_sgd_steps_per_step = 1
    num_steps = 5

    # Create a fake environment to test with.
    environment = fakes.ContinuousEnvironment(
        episode_length=10, bounded=True, action_dim=6)
    spec = specs.make_environment_spec(environment)
    dataset_demonstration = fakes.transition_dataset(environment)
    dataset_demonstration = dataset_demonstration.map(
        lambda sample: types.Transition(*sample.data))
    dataset_demonstration = dataset_demonstration.batch(8).as_numpy_iterator()

    # Construct the agent.
    network = make_networks(spec)

    if loss_name == 'logp':
      loss_fn = bc.logp(
          logp_fn=lambda dist_params, actions: dist_params.log_prob(actions))
    elif loss_name == 'mse':
      loss_fn = bc.mse(
          sample_fn=lambda dist_params, key: dist_params.sample(seed=key))
    elif loss_name == 'peerbc':
      base_loss_fn = bc.logp(
          logp_fn=lambda dist_params, actions: dist_params.log_prob(actions))
      loss_fn = bc.peerbc(base_loss_fn, zeta=0.1)
    else:
      raise ValueError(f'Unknown loss name: {loss_name}')

    learner = bc.BCLearner(
        network=network,
        random_key=jax.random.PRNGKey(0),
        loss_fn=loss_fn,
        optimizer=optax.adam(0.01),
        demonstrations=dataset_demonstration,
        num_sgd_steps_per_step=num_sgd_steps_per_step)

    # Train the agent.
    for _ in range(num_steps):
      learner.step()
def test_continuous(self):
  env = wrappers.SinglePrecisionWrapper(
      fakes.ContinuousEnvironment(
          action_dim=0, dtype=np.float64, reward_dtype=np.float64))

  self.assertTrue(np.issubdtype(env.observation_spec().dtype, np.float32))
  self.assertTrue(np.issubdtype(env.action_spec().dtype, np.float32))
  self.assertTrue(np.issubdtype(env.reward_spec().dtype, np.float32))
  self.assertTrue(np.issubdtype(env.discount_spec().dtype, np.float32))

  timestep = env.reset()
  self.assertIsNone(timestep.reward)
  self.assertIsNone(timestep.discount)
  self.assertTrue(np.issubdtype(timestep.observation.dtype, np.float32))

  timestep = env.step(0.0)
  self.assertTrue(np.issubdtype(timestep.reward.dtype, np.float32))
  self.assertTrue(np.issubdtype(timestep.discount.dtype, np.float32))
  self.assertTrue(np.issubdtype(timestep.observation.dtype, np.float32))
def test_dmpo(self):
  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(episode_length=10)
  spec = specs.make_environment_spec(environment)

  # Create networks.
  agent_networks = make_networks(spec.actions)

  # Construct the agent.
  agent = dmpo.DistributionalMPO(
      spec,
      policy_network=agent_networks['policy'],
      critic_network=agent_networks['critic'],
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=2)
def test_control_suite(self):
  """Tests that the agent can run on the control suite without crashing."""
  agent = svg0_prior.DistributedSVG0(
      environment_factory=lambda x: fakes.ContinuousEnvironment(),
      network_factory=make_networks,
      num_actors=2,
      batch_size=32,
      min_replay_size=32,
      max_replay_size=1000,
  )
  program = agent.build()

  (learner_node,) = program.groups['learner']
  learner_node.disable_run()

  lp.launch(program, launch_type='test_mt')

  learner: acme.Learner = learner_node.create_handle().dereference()
  for _ in range(5):
    learner.step()
def test_agent(self):
  agent = mpo.DistributedMPO(
      environment_factory=lambda x: fakes.ContinuousEnvironment(bounded=True),
      network_factory=make_networks,
      num_actors=2,
      batch_size=32,
      min_replay_size=32,
      max_replay_size=1000,
  )
  program = agent.build()

  (learner_node,) = program.groups['learner']
  learner_node.disable_run()

  lp.launch(program, launch_type='test_mt')

  learner: acme.Learner = learner_node.create_handle().dereference()
  for _ in range(5):
    learner.step()
def test_sac(self):
  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(
      episode_length=10, action_dim=3, observation_dim=5, bounded=True)
  spec = specs.make_environment_spec(environment)

  # Create the networks.
  network = networks.make_networks(spec)

  batch_size = 10
  config = sac_config.SACConfig(
      batch_size=batch_size,
      target_entropy=sac_config.target_entropy_from_env_spec(spec),
      min_replay_size=1)
  counter = counting.Counter()
  agent = agents.SAC(
      spec=spec,
      network=network,
      config=config,
      seed=0,
      normalize_input=True,
      counter=counter)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=2)
def main(_):
  environment = fakes.ContinuousEnvironment(
      action_dim=8, observation_dim=87, episode_length=10000000)
  spec = specs.make_environment_spec(environment)
  replay_tables = make_replay_tables(spec)
  replay_server = reverb.Server(replay_tables, port=None)
  replay_client = reverb.Client(f'localhost:{replay_server.port}')
  adder = make_adder(replay_client)

  timestep = environment.reset()
  adder.add_first(timestep)
  # TODO(raveman): Consider also filling the table to say 1M (too slow).
  for steps in range(10000):
    if steps % 1000 == 0:
      logging.info('Processed %s steps', steps)
    action = np.asarray(np.random.uniform(-1, 1, (8,)), dtype=np.float32)
    next_timestep = environment.step(action)
    adder.add(action, next_timestep, extras=())

  for batch_size in [256, 256 * 8, 256 * 64]:
    for prefetch_size in [0, 1, 4]:
      print(f'Processing batch_size={batch_size} '
            f'prefetch_size={prefetch_size}')
      ds = datasets.make_reverb_dataset(
          table='default',
          server_address=replay_client.server_address,
          batch_size=batch_size,
          prefetch_size=prefetch_size,
      )
      it = ds.as_numpy_iterator()

      for iteration in range(3):
        t = time.time()
        for _ in range(1000):
          _ = next(it)
        print(f'Iteration {iteration} finished in {time.time() - t}s')
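# `make_replay_tables` and `make_adder` are referenced by `main` above but
# not shown. A minimal sketch, assuming a single uniformly-sampled Reverb
# table named 'default' fed by an N-step transition adder; the table
# settings and the n_step/discount values are illustrative assumptions:
from acme.adders import reverb as adders_reverb


def make_replay_tables(environment_spec):
  return [
      reverb.Table(
          name='default',
          sampler=reverb.selectors.Uniform(),
          remover=reverb.selectors.Fifo(),
          max_size=1_000_000,
          rate_limiter=reverb.rate_limiters.MinSize(1),
          signature=adders_reverb.NStepTransitionAdder.signature(
              environment_spec))
  ]


def make_adder(replay_client):
  return adders_reverb.NStepTransitionAdder(
      replay_client, n_step=1, discount=0.99)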
def make_env(seed):
  del seed
  return fakes.ContinuousEnvironment(
      episode_length=10, action_dim=3, observation_dim=5, bounded=True)
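# `make_env` matches the single-argument `environment_factory` signature used
# by the distributed agents in this section, so it can be passed through
# directly. A short standalone usage sketch:
def _make_env_usage_sketch():
  environment = make_env(seed=0)
  return specs.make_environment_spec(environment)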
def test_ail_flax(self):
  shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
  batch_size = 8
  # Mujoco environment and associated demonstration dataset.
  environment = fakes.ContinuousEnvironment(
      episode_length=EPISODE_LENGTH,
      action_dim=CONTINUOUS_ACTION_DIM,
      observation_dim=CONTINUOUS_OBS_DIM,
      bounded=True)
  spec = specs.make_environment_spec(environment)

  networks = sac.make_networks(spec=spec)
  config = sac.SACConfig(
      batch_size=batch_size,
      samples_per_insert_tolerance_rate=float('inf'),
      min_replay_size=1)
  base_builder = sac.SACBuilder(config=config)
  direct_rl_batch_size = batch_size
  behavior_policy = sac.apply_policy_and_sample(networks)

  discriminator_module = DiscriminatorModule(spec, linen.Dense(1))

  def apply_fn(params: networks_lib.Params,
               policy_params: networks_lib.Params, state: networks_lib.Params,
               transitions: types.Transition, is_training: bool,
               rng: networks_lib.PRNGKey) -> networks_lib.Logits:
    del policy_params
    variables = dict(params=params, **state)
    return discriminator_module.apply(
        variables,
        transitions.observation,
        transitions.action,
        transitions.next_observation,
        is_training=is_training,
        rng=rng,
        mutable=state.keys())

  def init_fn(rng):
    variables = discriminator_module.init(
        rng, dummy_obs, dummy_actions, dummy_obs, is_training=False, rng=rng)
    init_state, discriminator_params = variables.pop('params')
    return discriminator_params, init_state

  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  dummy_actions = utils.zeros_like(spec.actions)
  dummy_actions = utils.add_batch_dim(dummy_actions)

  discriminator_network = networks_lib.FeedForwardNetwork(
      init=init_fn, apply=apply_fn)

  networks = ail.AILNetworks(discriminator_network, lambda x: x, networks)

  builder = ail.AILBuilder(
      base_builder,
      config=ail.AILConfig(
          is_sequence_based=False,
          share_iterator=True,
          direct_rl_batch_size=direct_rl_batch_size,
          discriminator_batch_size=2,
          policy_variable_name=None,
          min_replay_size=1),
      discriminator_loss=ail.losses.gail_loss(),
      make_demonstrations=fakes.transition_iterator(environment))

  counter = counting.Counter()
  # Construct the agent.
  agent = local_layout.LocalLayout(
      seed=0,
      environment_spec=spec,
      builder=builder,
      networks=networks,
      policy_network=behavior_policy,
      min_replay_size=1,
      batch_size=batch_size,
      counter=counter,
  )

  # Train the agent.
  train_loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  train_loop.run(num_episodes=1)
def test_ail(self,
             algo,
             airl_discriminator=False,
             subtract_logpi=False,
             dropout=0.,
             lipschitz_coeff=None):
  shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
  batch_size = 8
  # Mujoco environment and associated demonstration dataset.
  if algo == 'ppo':
    environment = fakes.DiscreteEnvironment(
        num_actions=NUM_DISCRETE_ACTIONS,
        num_observations=NUM_OBSERVATIONS,
        obs_shape=OBS_SHAPE,
        obs_dtype=OBS_DTYPE,
        episode_length=EPISODE_LENGTH)
  else:
    environment = fakes.ContinuousEnvironment(
        episode_length=EPISODE_LENGTH,
        action_dim=CONTINUOUS_ACTION_DIM,
        observation_dim=CONTINUOUS_OBS_DIM,
        bounded=True)
  spec = specs.make_environment_spec(environment)

  if algo == 'sac':
    networks = sac.make_networks(spec=spec)
    config = sac.SACConfig(
        batch_size=batch_size,
        samples_per_insert_tolerance_rate=float('inf'),
        min_replay_size=1)
    base_builder = sac.SACBuilder(config=config)
    direct_rl_batch_size = batch_size
    behavior_policy = sac.apply_policy_and_sample(networks)
  elif algo == 'ppo':
    unroll_length = 5
    distribution_value_networks = make_ppo_networks(spec)
    networks = ppo.make_ppo_networks(distribution_value_networks)
    config = ppo.PPOConfig(
        unroll_length=unroll_length,
        num_minibatches=2,
        num_epochs=4,
        batch_size=batch_size)
    base_builder = ppo.PPOBuilder(config=config)
    direct_rl_batch_size = batch_size * unroll_length
    behavior_policy = jax.jit(ppo.make_inference_fn(networks), backend='cpu')
  else:
    raise ValueError(f'Unexpected algorithm {algo}')

  if subtract_logpi:
    assert algo == 'sac'
    logpi_fn = make_sac_logpi(networks)
  else:
    logpi_fn = None

  if algo == 'ppo':
    embedding = lambda x: jnp.reshape(x, list(x.shape[:-2]) + [-1])
  else:
    embedding = lambda x: x

  def discriminator(*args, **kwargs) -> networks_lib.Logits:
    if airl_discriminator:
      return ail.AIRLModule(
          environment_spec=spec,
          use_action=True,
          use_next_obs=True,
          discount=.99,
          g_core=ail.DiscriminatorMLP(
              [4, 4],
              hidden_dropout_rate=dropout,
              spectral_normalization_lipschitz_coeff=lipschitz_coeff),
          h_core=ail.DiscriminatorMLP(
              [4, 4],
              hidden_dropout_rate=dropout,
              spectral_normalization_lipschitz_coeff=lipschitz_coeff),
          observation_embedding=embedding)(*args, **kwargs)
    else:
      return ail.DiscriminatorModule(
          environment_spec=spec,
          use_action=True,
          use_next_obs=True,
          network_core=ail.DiscriminatorMLP(
              [4, 4],
              hidden_dropout_rate=dropout,
              spectral_normalization_lipschitz_coeff=lipschitz_coeff),
          observation_embedding=embedding)(*args, **kwargs)

  discriminator_transformed = hk.without_apply_rng(
      hk.transform_with_state(discriminator))

  discriminator_network = ail.make_discriminator(
      environment_spec=spec,
      discriminator_transformed=discriminator_transformed,
      logpi_fn=logpi_fn)

  networks = ail.AILNetworks(discriminator_network, lambda x: x, networks)

  builder = ail.AILBuilder(
      base_builder,
      config=ail.AILConfig(
          is_sequence_based=(algo == 'ppo'),
          share_iterator=True,
          direct_rl_batch_size=direct_rl_batch_size,
          discriminator_batch_size=2,
          policy_variable_name='policy' if subtract_logpi else None,
          min_replay_size=1),
      discriminator_loss=ail.losses.gail_loss(),
      make_demonstrations=fakes.transition_iterator(environment))

  # Construct the agent.
  agent = local_layout.LocalLayout(
      seed=0,
      environment_spec=spec,
      builder=builder,
      networks=networks,
      policy_network=behavior_policy,
      min_replay_size=1,
      batch_size=batch_size)

  # Train the agent.
  train_loop = acme.EnvironmentLoop(environment, agent)
  train_loop.run(num_episodes=(10 if algo == 'ppo' else 1))