def main(_):
  """Train an IMPALA agent on the environment from `make_environment`.

  The environment spec is derived from the environment, the network is built
  from the spec's action spec, and the agent is then run in an environment
  loop for `environment.bsuite_num_episodes` episodes.

  Args:
    _: Ignored; the body never reads this argument.
  """
  # Build the environment and derive its spec.
  env = make_environment()
  env_spec = specs.make_environment_spec(env)

  # Build the network to optimize and hand it to the agent.
  policy_network = make_network(env_spec.actions)
  learner = impala.IMPALA(
      environment_spec=env_spec,
      network=policy_network,
      sequence_length=3,
      sequence_period=3,
  )

  # Run the environment loop for the number of episodes the env reports.
  acme.EnvironmentLoop(env, learner).run(
      num_episodes=env.bsuite_num_episodes)  # pytype: disable=attribute-error
def test_impala(self):
  """Smoke-test that an IMPALA agent runs in an environment loop.

  Builds a fake discrete environment, constructs the agent from its spec and
  runs the loop for 20 episodes.
  """
  # Fake environment to test against.
  fake_env = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_dtype=np.float32,
      episode_length=10,
  )
  env_spec = specs.make_environment_spec(fake_env)

  # Agent under test.
  impala_agent = impala.IMPALA(
      environment_spec=env_spec,
      network=_make_network(env_spec.actions),
      sequence_length=3,
      sequence_period=3,
      batch_size=6,
  )

  # There are deliberately no assertions: the only thing checked is that the
  # agent runs through the loop without raising any errors.
  acme.EnvironmentLoop(fake_env, impala_agent).run(num_episodes=20)
def main(_):
  """Train an IMPALA agent on a bsuite environment, recording results to CSV.

  The environment is loaded with `bsuite.load_and_record_to_csv` using the
  `bsuite_id`, `results_dir` and `overwrite` flags, then wrapped in
  `SinglePrecisionWrapper`. The agent is run in an environment loop for
  `environment.bsuite_num_episodes` episodes.

  Args:
    _: Ignored; the body never reads this argument.
  """
  # Load the bsuite environment configured by the flags.
  recorded_env = bsuite.load_and_record_to_csv(
      bsuite_id=FLAGS.bsuite_id,
      results_dir=FLAGS.results_dir,
      overwrite=FLAGS.overwrite,
  )
  env = wrappers.SinglePrecisionWrapper(recorded_env)
  env_spec = specs.make_environment_spec(env)

  # Build the network to optimize and hand it to the agent.
  policy_network = make_network(env_spec.actions)
  learner = impala.IMPALA(
      environment_spec=env_spec,
      network=policy_network,
      sequence_length=3,
      sequence_period=3,
  )

  # Run the environment loop for the number of episodes the env reports.
  acme.EnvironmentLoop(env, learner).run(
      num_episodes=env.bsuite_num_episodes)  # pytype: disable=attribute-error