def test_dqfd(self):
  # Create a fake environment to test with.
  # TODO(b/152596848): Allow DQN to deal with integer observations.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  # Build demonstrations.
  dummy_action = np.zeros((), dtype=np.int32)
  recorder = bsuite_demonstrations.DemonstrationRecorder()
  timestep = environment.reset()
  while timestep.step_type is not dm_env.StepType.LAST:
    recorder.step(timestep, dummy_action)
    timestep = environment.step(dummy_action)
  recorder.step(timestep, dummy_action)
  recorder.record_episode()

  # Construct the agent.
  agent = dqfd.DQfD(
      environment_spec=spec,
      network=_make_network(spec.actions),
      demonstration_dataset=recorder.make_tf_dataset(),
      demonstration_ratio=0.5,
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=10)
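# A minimal sketch of the `_make_network` helper the Sonnet-based tests here
# assume; each real test module defines its own variant (the IMPALA test
# further below, for instance, needs a recurrent one), so the layer sizes
# are assumptions. For DQfD/DQN it is a feed-forward Q-network mapping
# observations to one value per action.
def _make_network(action_spec: specs.DiscreteArray) -> snt.Module:
  return snt.Sequential([
      snt.Flatten(),
      snt.nets.MLP([50, 50, action_spec.num_values]),
  ])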
def test_mcts(self):
  # Create a fake environment to test with.
  num_actions = 5
  environment = fakes.DiscreteEnvironment(
      num_actions=num_actions,
      num_observations=10,
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  network = snt.Sequential([
      snt.Flatten(),
      snt.nets.MLP([50, 50]),
      networks.PolicyValueHead(spec.actions.num_values),
  ])
  model = simulator.Simulator(environment)
  optimizer = snt.optimizers.Adam(1e-3)

  # Construct the agent.
  agent = mcts.MCTS(
      environment_spec=spec,
      network=network,
      model=model,
      optimizer=optimizer,
      n_step=1,
      discount=1.,
      replay_capacity=100,
      num_simulations=10,
      batch_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=2)
def test_dqn(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  def network(x):
    model = hk.Sequential(
        [hk.Flatten(), hk.nets.MLP([50, 50, spec.actions.num_values])])
    return model(x)

  # Make the network purely functional.
  network_hk = hk.without_apply_rng(hk.transform(network, apply_rng=True))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
  network = networks_lib.FeedForwardNetwork(
      init=lambda rng: network_hk.init(rng, dummy_obs),
      apply=network_hk.apply)

  # Construct the agent.
  agent = dqn.DQN(
      environment_spec=spec,
      network=network,
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)
def test_reset(self):
  """Ensure that a noop-starts `reset` steps the environment multiple times."""
  noop_action = 0
  noop_max = 10
  seed = 24

  base_env = fakes.DiscreteEnvironment(
      action_dtype=np.int64,
      obs_dtype=np.int64,
      reward_spec=specs.Array(dtype=np.float64, shape=()))
  mock_step_fn = mock.MagicMock()
  expected_num_step_calls = np.random.RandomState(seed).randint(noop_max + 1)

  with mock.patch.object(base_env, 'step', mock_step_fn):
    env = wrappers.NoopStartsWrapper(
        base_env,
        noop_action=noop_action,
        noop_max=noop_max,
        seed=seed,
    )
    env.reset()

    # The environment step should be called with the noop action as part of
    # the wrapper's reset.
    mock_step_fn.assert_called_with(noop_action)
    self.assertEqual(mock_step_fn.call_count, expected_num_step_calls)
    self.assertEqual(mock_step_fn.call_args, ((noop_action,), {}))
def run_ppo_agent(self, make_networks_fn):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  distribution_value_networks = make_networks_fn(spec)
  ppo_networks = ppo.make_ppo_networks(distribution_value_networks)
  config = ppo.PPOConfig(unroll_length=4, num_epochs=2, num_minibatches=2)
  workdir = self.create_tempdir()
  counter = counting.Counter()
  logger = loggers.make_default_logger('learner')

  # Construct the agent.
  agent = ppo.PPO(
      spec=spec,
      networks=ppo_networks,
      config=config,
      seed=0,
      workdir=workdir.full_path,
      normalize_input=True,
      counter=counter,
      logger=logger,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=20)
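# A minimal sketch of a factory that could be passed as `make_networks_fn`
# to run_ppo_agent above; the real test suite supplies its own factories, so
# the name, layer sizes, and distribution choice here are assumptions. The
# returned network's apply must yield an (action distribution, value) pair,
# which `ppo.make_ppo_networks` then wraps.
def make_discrete_ppo_networks(spec):
  def forward_fn(inputs):
    torso = hk.Sequential([hk.Flatten(), hk.nets.MLP([64, 64])])
    embeddings = torso(inputs)
    logits = hk.Linear(spec.actions.num_values)(embeddings)
    value = jnp.squeeze(hk.Linear(1)(embeddings), axis=-1)
    # tfd is assumed to be tensorflow_probability.substrates.jax.distributions.
    return tfd.Categorical(logits=logits), value

  forward = hk.without_apply_rng(hk.transform(forward_fn))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
  return networks_lib.FeedForwardNetwork(
      init=lambda rng: forward.init(rng, dummy_obs), apply=forward.apply)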
def test_impala(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  def network(x, s):
    model = MyNetwork(spec.actions.num_values)
    return model(x, s)

  def initial_state_fn(batch_size: Optional[int] = None):
    model = MyNetwork(spec.actions.num_values)
    return model.initial_state(batch_size)

  # Construct the agent.
  agent = impala.IMPALA(
      environment_spec=spec,
      network=network,
      initial_state_fn=initial_state_fn,
      sequence_length=3,
      sequence_period=3,
      batch_size=6,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)
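# A minimal sketch of the `MyNetwork` fixture the IMPALA tests above and
# further below assume; the real test modules define their own version, so
# the layer sizes and heads here are assumptions. IMPALA expects a recurrent
# core whose call returns (policy logits, value) plus the new state.
class MyNetwork(hk.RNNCore):
  """A simple recurrent policy-value network for testing."""

  def __init__(self, num_actions: int):
    super().__init__(name='my_network')
    self._torso = hk.Sequential([hk.Flatten(), hk.nets.MLP([50, 50])])
    self._core = hk.LSTM(20)
    self._policy_head = hk.Linear(num_actions)
    self._value_head = hk.Linear(1)

  def __call__(self, inputs, state):
    embeddings = self._torso(inputs)
    embeddings, new_state = self._core(embeddings, state)
    logits = self._policy_head(embeddings)
    value = jnp.squeeze(self._value_head(embeddings), axis=-1)
    return (logits, value), new_state

  def initial_state(self, batch_size: Optional[int]):
    return self._core.initial_state(batch_size)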
def test_r2d2(self):
  # Create a fake environment to test with.
  # TODO(b/152596848): Allow R2D2 to deal with integer observations.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 4),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  # Construct the agent.
  agent = r2d2.R2D2(
      environment_spec=spec,
      network=SimpleNetwork(spec.actions),
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10,
      burn_in_length=2,
      trace_length=6,
      replay_period=4,
      checkpoint=False,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=5)
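# A minimal sketch of the `SimpleNetwork` fixture assumed by the R2D2 and
# R2D3 tests; the real test modules define their own, so the RNN and MLP
# sizes here are assumptions. R2D2 needs a recurrent Q-network that also
# exposes `initial_state` and `unroll`.
class SimpleNetwork(networks.RNNCore):
  """A simple recurrent Q-network for testing."""

  def __init__(self, action_spec: specs.DiscreteArray):
    super().__init__(name='simple_test_network')
    self._net = snt.DeepRNN([
        snt.Flatten(),
        snt.LSTM(20),
        snt.nets.MLP([50, 50, action_spec.num_values]),
    ])

  def __call__(self, inputs, state):
    return self._net(inputs, state)

  def initial_state(self, batch_size: int, **unused_kwargs):
    return self._net.initial_state(batch_size)

  def unroll(self, inputs, state, sequence_length):
    return snt.static_unroll(self._net, inputs, state, sequence_length)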
def test_dqn(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  def network(x):
    model = hk.Sequential(
        [hk.Flatten(), hk.nets.MLP([50, 50, spec.actions.num_values])])
    return model(x)

  # Construct the agent.
  agent = dqn.DQN(
      environment_spec=spec,
      network=network,
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)
def test_environment_loop(self):
  # Create the actor/environment and stick them in a loop.
  environment = fakes.DiscreteEnvironment(episode_length=10)
  actor = fakes.Actor(specs.make_environment_spec(environment))
  loop = environment_loop.EnvironmentLoop(environment, actor)

  # Run the loop. There should be episode_length update calls per episode,
  # so 10 episodes of length 10 yield 100 updates in total.
  loop.run(num_episodes=10)
  self.assertEqual(actor.num_updates, 100)
def test_pickle_unpickle(self):
  test_env = base.EnvironmentWrapper(environment=fakes.DiscreteEnvironment())

  test_env_pickled = pickle.dumps(test_env)
  test_env_restored = pickle.loads(test_env_pickled)
  self.assertEqual(
      test_env.observation_spec(),
      test_env_restored.observation_spec(),
  )
def test_raises_value_error(self):
  """Ensure that the wrapper raises an error if noop_max is negative."""
  base_env = fakes.DiscreteEnvironment(
      action_dtype=np.int64,
      obs_dtype=np.int64,
      reward_spec=specs.Array(dtype=np.float64, shape=()))

  with self.assertRaises(ValueError):
    wrappers.NoopStartsWrapper(base_env, noop_action=0, noop_max=-1, seed=24)
def test_discrete_actions(self, loss_name):
  with chex.fake_pmap_and_jit():
    num_sgd_steps_per_step = 1
    num_steps = 5

    # Create a fake environment to test with.
    environment = fakes.DiscreteEnvironment(
        num_actions=10,
        num_observations=100,
        obs_shape=(10,),
        obs_dtype=np.float32)
    spec = specs.make_environment_spec(environment)

    dataset_demonstration = fakes.transition_dataset(environment)
    dataset_demonstration = dataset_demonstration.map(
        lambda sample: types.Transition(*sample.data))
    dataset_demonstration = dataset_demonstration.batch(8).as_numpy_iterator()

    # Construct the agent.
    network = make_networks(spec, discrete_actions=True)

    def logp_fn(logits, actions):
      # Compute a numerically stable log-probability of the taken actions.
      max_logits = jnp.max(logits, axis=-1, keepdims=True)
      logits = logits - max_logits
      logits_actions = jnp.sum(
          jax.nn.one_hot(actions, spec.actions.num_values) * logits, axis=-1)
      log_prob = logits_actions - special.logsumexp(logits, axis=-1)
      return log_prob

    if loss_name == 'logp':
      loss_fn = bc.logp(logp_fn=logp_fn)
    elif loss_name == 'rcal':
      base_loss_fn = bc.logp(logp_fn=logp_fn)
      loss_fn = bc.rcal(base_loss_fn, discount=0.99, alpha=0.1)
    else:
      raise ValueError(f'Unknown loss name: {loss_name}')

    learner = bc.BCLearner(
        network=network,
        random_key=jax.random.PRNGKey(0),
        loss_fn=loss_fn,
        optimizer=optax.adam(0.01),
        demonstrations=dataset_demonstration,
        num_sgd_steps_per_step=num_sgd_steps_per_step)

    # Train the agent.
    for _ in range(num_steps):
      learner.step()
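# A minimal sketch of the `make_networks` helper assumed by the BC test
# above; the real test module defines its own, so the layer sizes are
# assumptions. For discrete actions the network must map observations to
# per-action logits that the `logp_fn` above can consume.
def make_networks(spec, discrete_actions: bool = False):
  if not discrete_actions:
    raise NotImplementedError('This sketch covers only the discrete case.')

  def forward_fn(inputs):
    return hk.Sequential(
        [hk.Flatten(), hk.nets.MLP([64, 64, spec.actions.num_values])])(inputs)

  forward = hk.without_apply_rng(hk.transform(forward_fn))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
  return networks_lib.FeedForwardNetwork(
      init=lambda rng: forward.init(rng, dummy_obs), apply=forward.apply)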
def test_impala(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  def forward_fn(x, s):
    model = MyNetwork(spec.actions.num_values)
    return model(x, s)

  def initial_state_fn(batch_size: Optional[int] = None):
    model = MyNetwork(spec.actions.num_values)
    return model.initial_state(batch_size)

  def unroll_fn(inputs, state):
    model = MyNetwork(spec.actions.num_values)
    return hk.static_unroll(model, inputs, state)

  # We pass pure, Haiku-agnostic functions to the agent.
  forward_fn_transformed = hk.without_apply_rng(
      hk.transform(forward_fn, apply_rng=True))
  unroll_fn_transformed = hk.without_apply_rng(
      hk.transform(unroll_fn, apply_rng=True))
  initial_state_fn_transformed = hk.without_apply_rng(
      hk.transform(initial_state_fn, apply_rng=True))

  # Construct the agent.
  config = impala_agent.IMPALAConfig(
      sequence_length=3,
      sequence_period=3,
      batch_size=6,
  )
  agent = impala.IMPALAFromConfig(
      environment_spec=spec,
      forward_fn=forward_fn_transformed.apply,
      initial_state_init_fn=initial_state_fn_transformed.init,
      initial_state_fn=initial_state_fn_transformed.apply,
      unroll_init_fn=unroll_fn_transformed.init,
      unroll_fn=unroll_fn_transformed.apply,
      config=config,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)
def test_discrete(self):
  env = wrappers.SinglePrecisionWrapper(
      fakes.DiscreteEnvironment(
          action_dtype=np.int64, obs_dtype=np.int64, reward_dtype=np.float64))

  self.assertTrue(np.issubdtype(env.observation_spec().dtype, np.int32))
  self.assertTrue(np.issubdtype(env.action_spec().dtype, np.int32))
  self.assertTrue(np.issubdtype(env.reward_spec().dtype, np.float32))
  self.assertTrue(np.issubdtype(env.discount_spec().dtype, np.float32))

  timestep = env.reset()
  self.assertIsNone(timestep.reward)
  self.assertIsNone(timestep.discount)
  self.assertTrue(np.issubdtype(timestep.observation.dtype, np.int32))

  timestep = env.step(0)
  self.assertTrue(np.issubdtype(timestep.reward.dtype, np.float32))
  self.assertTrue(np.issubdtype(timestep.discount.dtype, np.float32))
  self.assertTrue(np.issubdtype(timestep.observation.dtype, np.int32))
def test_dqn(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  # Construct the agent.
  agent = dqn.DQN(
      environment_spec=spec,
      network=_make_network(spec.actions),
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=2)
def _parameterized_setup(discount_spec: Optional[types.NestedSpec] = None,
                         reward_spec: Optional[types.NestedSpec] = None):
  """Common setup code that, unlike self.setUp, takes arguments.

  Args:
    discount_spec: None, or a (nested) specs.BoundedArray.
    reward_spec: None, or a (nested) specs.Array.

  Returns:
    actor, loop
  """
  env_kwargs = {'episode_length': EPISODE_LENGTH}
  if discount_spec:
    env_kwargs['discount_spec'] = discount_spec
  if reward_spec:
    env_kwargs['reward_spec'] = reward_spec

  environment = fakes.DiscreteEnvironment(**env_kwargs)
  actor = fakes.Actor(specs.make_environment_spec(environment))
  loop = environment_loop.EnvironmentLoop(environment, actor)
  return actor, loop
def test_impala(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  # Construct the agent.
  agent = impala.IMPALA(
      environment_spec=spec,
      network=_make_network(spec.actions),
      sequence_length=3,
      sequence_period=3,
      batch_size=6,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)
def test_full_learner(self):
  # Create dataset.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)
  dataset = fakes.transition_dataset(environment).batch(2, drop_remainder=True)

  # Build network.
  network = networks.IQNNetwork(
      torso=_make_torso_network(num_outputs=2),
      head=_make_head_network(num_outputs=spec.actions.num_values),
      latent_dim=2,
      num_quantile_samples=1)
  tf2_utils.create_variables(network, [spec.observations])

  # Build learner.
  counter = counting.Counter()
  learner = iqn.IQNLearner(
      network=network,
      target_network=copy.deepcopy(network),
      dataset=dataset,
      learning_rate=1e-4,
      discount=0.99,
      importance_sampling_exponent=0.2,
      target_update_period=1,
      counter=counter)

  # Run a learner step.
  learner.step()

  # Check counts from the IQN learner.
  counts = counter.get_counts()
  self.assertEqual(1, counts['steps'])

  # Check learner state.
  self.assertEqual(1, learner.state['num_steps'].numpy())
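# Minimal sketches of the torso/head helpers the IQN learner test above
# assumes; the real test module defines its own, so the exact layers are
# assumptions. The torso embeds observations into the latent space and the
# head maps the quantile-conditioned embedding to per-action values.
def _make_torso_network(num_outputs: int) -> snt.Module:
  return snt.Sequential([snt.Flatten(), snt.Linear(num_outputs)])


def _make_head_network(num_outputs: int) -> snt.Module:
  return snt.nets.MLP([num_outputs])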
def test_full_learner(self):
  # Create dataset.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)
  dataset = fakes.transition_dataset(environment).batch(2)

  # Build network.
  g_network = _make_network(spec.actions)
  q_network = _make_network(spec.actions)
  network = discrete_networks.DiscreteFilteredQNetwork(
      g_network=g_network,
      q_network=q_network,
      threshold=0.5)
  tf2_utils.create_variables(network, [spec.observations])

  # Build learner.
  counter = counting.Counter()
  learner = bcq.DiscreteBCQLearner(
      network=network,
      dataset=dataset,
      learning_rate=1e-4,
      discount=0.99,
      importance_sampling_exponent=0.2,
      target_update_period=100,
      counter=counter)

  # Run a learner step.
  learner.step()

  # Check counts from BC and BCQ learners.
  counts = counter.get_counts()
  self.assertEqual(1, counts['bc_steps'])
  self.assertEqual(1, counts['bcq_steps'])

  # Check learner state.
  self.assertEqual(1, learner.state['bc_num_steps'].numpy())
  self.assertEqual(1, learner.state['bcq_num_steps'].numpy())
def test_r2d3(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  # Build demonstrations.
  dummy_action = np.zeros((), dtype=np.int32)
  recorder = bsuite_demonstrations.DemonstrationRecorder()
  timestep = environment.reset()
  while timestep.step_type is not dm_env.StepType.LAST:
    recorder.step(timestep, dummy_action)
    timestep = environment.step(dummy_action)
  recorder.step(timestep, dummy_action)
  recorder.record_episode()

  # Construct the agent.
  agent = r2d3.R2D3(
      environment_spec=spec,
      network=SimpleNetwork(spec.actions),
      target_network=SimpleNetwork(spec.actions),
      demonstration_dataset=recorder.make_tf_dataset(),
      demonstration_ratio=0.5,
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10,
      burn_in_length=2,
      trace_length=6,
      replay_period=4,
      checkpoint=False,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=5)
def test_ail(self,
             algo,
             airl_discriminator=False,
             subtract_logpi=False,
             dropout=0.,
             lipschitz_coeff=None):
  shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
  batch_size = 8

  # Fake environment and associated demonstration dataset.
  if algo == 'ppo':
    environment = fakes.DiscreteEnvironment(
        num_actions=NUM_DISCRETE_ACTIONS,
        num_observations=NUM_OBSERVATIONS,
        obs_shape=OBS_SHAPE,
        obs_dtype=OBS_DTYPE,
        episode_length=EPISODE_LENGTH)
  else:
    environment = fakes.ContinuousEnvironment(
        episode_length=EPISODE_LENGTH,
        action_dim=CONTINUOUS_ACTION_DIM,
        observation_dim=CONTINUOUS_OBS_DIM,
        bounded=True)
  spec = specs.make_environment_spec(environment)

  if algo == 'sac':
    networks = sac.make_networks(spec=spec)
    config = sac.SACConfig(
        batch_size=batch_size,
        samples_per_insert_tolerance_rate=float('inf'),
        min_replay_size=1)
    base_builder = sac.SACBuilder(config=config)
    direct_rl_batch_size = batch_size
    behavior_policy = sac.apply_policy_and_sample(networks)
  elif algo == 'ppo':
    unroll_length = 5
    distribution_value_networks = make_ppo_networks(spec)
    networks = ppo.make_ppo_networks(distribution_value_networks)
    config = ppo.PPOConfig(
        unroll_length=unroll_length,
        num_minibatches=2,
        num_epochs=4,
        batch_size=batch_size)
    base_builder = ppo.PPOBuilder(config=config)
    direct_rl_batch_size = batch_size * unroll_length
    behavior_policy = jax.jit(ppo.make_inference_fn(networks), backend='cpu')
  else:
    raise ValueError(f'Unexpected algorithm {algo}')

  if subtract_logpi:
    assert algo == 'sac'
    logpi_fn = make_sac_logpi(networks)
  else:
    logpi_fn = None

  if algo == 'ppo':
    embedding = lambda x: jnp.reshape(x, list(x.shape[:-2]) + [-1])
  else:
    embedding = lambda x: x

  def discriminator(*args, **kwargs) -> networks_lib.Logits:
    if airl_discriminator:
      return ail.AIRLModule(
          environment_spec=spec,
          use_action=True,
          use_next_obs=True,
          discount=.99,
          g_core=ail.DiscriminatorMLP(
              [4, 4],
              hidden_dropout_rate=dropout,
              spectral_normalization_lipschitz_coeff=lipschitz_coeff),
          h_core=ail.DiscriminatorMLP(
              [4, 4],
              hidden_dropout_rate=dropout,
              spectral_normalization_lipschitz_coeff=lipschitz_coeff),
          observation_embedding=embedding)(*args, **kwargs)
    else:
      return ail.DiscriminatorModule(
          environment_spec=spec,
          use_action=True,
          use_next_obs=True,
          network_core=ail.DiscriminatorMLP(
              [4, 4],
              hidden_dropout_rate=dropout,
              spectral_normalization_lipschitz_coeff=lipschitz_coeff),
          observation_embedding=embedding)(*args, **kwargs)

  discriminator_transformed = hk.without_apply_rng(
      hk.transform_with_state(discriminator))

  discriminator_network = ail.make_discriminator(
      environment_spec=spec,
      discriminator_transformed=discriminator_transformed,
      logpi_fn=logpi_fn)

  networks = ail.AILNetworks(discriminator_network, lambda x: x, networks)

  builder = ail.AILBuilder(
      base_builder,
      config=ail.AILConfig(
          is_sequence_based=(algo == 'ppo'),
          share_iterator=True,
          direct_rl_batch_size=direct_rl_batch_size,
          discriminator_batch_size=2,
          policy_variable_name='policy' if subtract_logpi else None,
          min_replay_size=1),
      discriminator_loss=ail.losses.gail_loss(),
      make_demonstrations=fakes.transition_iterator(environment))

  # Construct the agent.
  agent = local_layout.LocalLayout(
      seed=0,
      environment_spec=spec,
      builder=builder,
      networks=networks,
      policy_network=behavior_policy,
      min_replay_size=1,
      batch_size=batch_size)

  # Train the agent.
  train_loop = acme.EnvironmentLoop(environment, agent)
  train_loop.run(num_episodes=(10 if algo == 'ppo' else 1))
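# A minimal sketch of the `make_sac_logpi` helper assumed by test_ail when
# subtract_logpi is set (the real test module defines its own). It builds a
# log-probability function from the SAC policy network; that the policy
# apply returns an action distribution follows Acme's SAC conventions but
# is an assumption here.
def make_sac_logpi(sac_networks):
  def logpi_fn(params, observations, actions):
    dist = sac_networks.policy_network.apply(params, observations)
    return dist.log_prob(actions)
  return logpi_fn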
def setUp(self):
  super().setUp()

  # Create the actor/environment and stick them in a loop.
  environment = fakes.DiscreteEnvironment(episode_length=EPISODE_LENGTH)
  self.actor = fakes.Actor(specs.make_environment_spec(environment))
  self.loop = environment_loop.EnvironmentLoop(environment, self.actor)
def test_deepcopy(self):
  test_env = base.EnvironmentWrapper(environment=fakes.DiscreteEnvironment())
  copied_env = copy.deepcopy(test_env)
  del copied_env