def test_rlrl(env_id, graph_param, seed, num_envs):
  adapt_envs, test_envs = make_envs(env_id, graph_param, seed, num_envs)
  environment_spec = specs.make_environment_spec(adapt_envs)

  if env_id in {'playground', 'mining'}:  # Spatial observation.
    network = snt_utils.CombinedNN(environment_spec.actions)
  else:
    network = snt_utils.RecurrentNN(environment_spec.actions)

  # Create meta agent.
  meta_agent = agents.RLRL(
      environment_spec=environment_spec,
      network=network,
      n_step_horizon=10,
      minibatch_size=10)

  # Run meta loop.
  meta_loop = environment_loop.EnvironmentMetaLoop(
      adapt_environment=adapt_envs,
      test_environment=test_envs,
      meta_agent=meta_agent,
      label='meta_train')  # XXX meta train

  meta_loop.run(
      num_trials=3,
      num_adapt_steps=100,  # Should be greater than TimeLimit.
      num_test_episodes=1,
      num_trial_splits=1)
def test_td3_fd(self):
  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(
      episode_length=10, action_dim=3, observation_dim=5, bounded=True)
  spec = specs.make_environment_spec(environment)

  # Create the networks.
  td3_network = td3.make_networks(spec)

  batch_size = 10
  td3_config = td3.TD3Config(
      batch_size=batch_size,
      min_replay_size=1)
  lfd_config = lfd.LfdConfig(
      initial_insert_count=0,
      demonstration_ratio=0.2)
  td3_fd_config = lfd.TD3fDConfig(
      lfd_config=lfd_config,
      td3_config=td3_config)

  counter = counting.Counter()
  agent = lfd.TD3fD(
      spec=spec,
      td3_network=td3_network,
      td3_fd_config=td3_fd_config,
      lfd_iterator_fn=fake_demonstration_iterator,
      seed=0,
      counter=counter)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=20)
def actor(self, random_key, replay, variable_source, counter, actor_id):
  """The actor process."""
  adder = self._builder.make_adder(replay)

  environment_key, actor_key = jax.random.split(random_key)
  # Create environment and policy core.
  # Environments normally require uint32 as a seed.
  environment = self._environment_factory(
      utils.sample_uint32(environment_key))

  networks = self._network_factory(specs.make_environment_spec(environment))
  policy_network = self._policy_network(networks)
  actor = self._builder.make_actor(actor_key, policy_network, adder,
                                   variable_source)

  # Create logger and counter.
  counter = counting.Counter(counter, 'actor')
  # Only actor #0 will write to bigtable in order not to spam it too much.
  logger = self._actor_logger_fn(actor_id)

  # Create the loop to connect environment and agent.
  return environment_loop.EnvironmentLoop(environment, actor, counter, logger,
                                          observers=self._observers)
def test_ppo_nest_safety(self):
  # Create a fake environment with nested observations.
  environment = fakes.NestedDiscreteEnvironment(
      num_observations={'lat': 2, 'long': 3},
      num_actions=5,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  distribution_value_networks = make_haiku_networks(spec)
  ppo_networks = ppo.make_ppo_networks(distribution_value_networks)
  config = ppo.PPOConfig(unroll_length=4, num_epochs=2, num_minibatches=2)
  workdir = self.create_tempdir()

  # Construct the agent.
  agent = ppo.PPO(
      spec=spec,
      networks=ppo_networks,
      config=config,
      seed=0,
      workdir=workdir.full_path,
      normalize_input=True,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)
def test_r2d2(self):
  # Create a fake environment to test with.
  # TODO(b/152596848): Allow R2D2 to deal with integer observations.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 4),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  # Construct the agent.
  agent = r2d2.R2D2(
      environment_spec=spec,
      network=SimpleNetwork(spec.actions),
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10,
      burn_in_length=2,
      trace_length=6,
      replay_period=4,
      checkpoint=False,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=5)
def test_r2d2(self):
  # Create a fake environment to test with.
  environment = fakes.fake_atari_wrapped(oar_wrapper=True)
  spec = specs.make_environment_spec(environment)

  config = r2d2.R2D2Config(
      batch_size=1,
      trace_length=5,
      sequence_period=1,
      samples_per_insert=0.,
      min_replay_size=1,
      burn_in_length=1)

  counter = counting.Counter()
  agent = r2d2.R2D2(
      spec=spec,
      networks=r2d2.make_atari_networks(config.batch_size, spec),
      config=config,
      seed=0,
      counter=counter,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=20)
def test_mcts(self):
  # Create a fake environment to test with.
  num_actions = 5
  environment = fakes.DiscreteEnvironment(
      num_actions=num_actions,
      num_observations=10,
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  network = snt.Sequential([
      snt.Flatten(),
      snt.nets.MLP([50, 50]),
      networks.PolicyValueHead(spec.actions.num_values),
  ])
  model = simulator.Simulator(environment)
  optimizer = snt.optimizers.Adam(1e-3)

  # Construct the agent.
  agent = mcts.MCTS(
      environment_spec=spec,
      network=network,
      model=model,
      optimizer=optimizer,
      n_step=1,
      discount=1.,
      replay_capacity=100,
      num_simulations=10,
      batch_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=2)
def test_impala(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  def network(x, s):
    model = MyNetwork(spec.actions.num_values)
    return model(x, s)

  def initial_state_fn(batch_size: Optional[int] = None):
    model = MyNetwork(spec.actions.num_values)
    return model.initial_state(batch_size)

  # Construct the agent.
  agent = impala.IMPALA(
      environment_spec=spec,
      network=network,
      initial_state_fn=initial_state_fn,
      sequence_length=3,
      sequence_period=3,
      batch_size=6,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)
def main(_):
  # Create an environment, grab the spec, and use it to create networks.
  environment = make_environment(FLAGS.task_name)
  environment_spec = specs.make_environment_spec(environment)
  agent_networks = make_networks(environment_spec)

  # Construct the agent.
  agent = sac.SAC(
      environment_spec=environment_spec,
      policy_network=agent_networks['policy'],
      critic_network=agent_networks['critic'],
      encoder_network=agent_networks['observation'],  # pytype: disable=wrong-arg-types
  )

  # Create the environment loop used for training.
  train_loop = acme.EnvironmentLoop(environment, agent, label='train_loop')

  # Create the evaluation policy.
  eval_policy = agent.behavior_network

  # Create the evaluation actor and loop.
  eval_actor = actors.FeedForwardActor(policy_network=eval_policy)
  eval_env = make_environment(FLAGS.task_name)
  eval_loop = acme.EnvironmentLoop(eval_env, eval_actor, label='eval_loop')

  # Alternate between training and a single evaluation episode.
  for _ in range(FLAGS.num_episodes // FLAGS.num_episodes_per_eval):
    train_loop.run(num_episodes=FLAGS.num_episodes_per_eval)
    eval_loop.run(num_episodes=1)
def generate_train_data(
    task_name,
    behavior_policy_param,
    dataset_path,
    environment,
    dataset_size,
    batch_size,
    shuffle,
    include_terminal=False,  # Include terminal absorbing state.
    ignore_d_tm1=False  # Set d_tm1 as constant 1.0 if True.
):
  environment_spec = specs.make_environment_spec(environment)
  with tf.device('CPU'):
    behavior_policy_net = load_policy_net(
        task_name=task_name,
        params=behavior_policy_param,
        environment_spec=environment_spec,
        dataset_path=dataset_path)

    logging.info('start generating transitions')
    dataset = _generate_data(
        behavior_policy_net,
        environment,
        dataset_size,
        batch_size,
        shuffle,
        include_terminal=include_terminal,
        ignore_d_tm1=ignore_d_tm1)
    logging.info('end generating transitions')

  return dataset
def test_step(self):
  for param in TEST_PARAMS:
    env = self._make_parallel_environments(
        env_id=param.env_id,
        num_envs=param.num_envs,
        graph_param=param.graph_param)
    env_spec = specs.make_environment_spec(env)

    # Create random actor.
    actor = agents.RandomActor(env_spec)

    env.reset_task(task_index=0)
    ts = env.reset()

    # Step.
    action = actor.select_action(ts.observation)
    ts = env.step(action)

    self.assertIsInstance(ts, dm_env.TimeStep)
    self.assertFalse(any(ts.first()))

    ob = ts.observation
    self.assertEqual(ob['mask'].shape, (param.num_envs, param.action_dim))
    self.assertEqual(ob['completion'].shape,
                     (param.num_envs, param.action_dim))
    self.assertEqual(ob['eligibility'].shape,
                     (param.num_envs, param.action_dim))
def evaluator(
    random_key: types.PRNGKey,
    variable_source: core.VariableSource,
    counter: counting.Counter,
    make_actor: MakeActorFn,
):
  """The evaluation process."""
  # Create environment and evaluator networks.
  environment_key, actor_key = jax.random.split(random_key)
  # Environments normally require uint32 as a seed.
  environment = environment_factory(utils.sample_uint32(environment_key))
  environment_spec = specs.make_environment_spec(environment)

  networks = network_factory(environment_spec)
  policy = policy_factory(networks, environment_spec, True)
  actor = make_actor(actor_key, policy, environment_spec, variable_source)

  # Create logger and counter.
  counter = counting.Counter(counter, 'evaluator')
  logger = logger_factory('evaluator', 'actor_steps', 0)

  # Create the run loop and return it.
  return environment_loop.EnvironmentLoop(environment, actor, counter, logger,
                                          observers=observers)
def test_dqn(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  def network(x):
    model = hk.Sequential(
        [hk.Flatten(), hk.nets.MLP([50, 50, spec.actions.num_values])])
    return model(x)

  # Construct the agent.
  agent = dqn.DQN(
      environment_spec=spec,
      network=network,
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)
def test_d4pg(self):
  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(
      episode_length=10, action_dim=3, observation_dim=5, bounded=True)
  spec = specs.make_environment_spec(environment)

  # Create the networks.
  networks = make_networks(spec)

  config = d4pg.D4PGConfig(
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10,
      samples_per_insert_tolerance_rate=float('inf'))
  counter = counting.Counter()
  agent = d4pg.D4PG(
      spec, networks, config=config, random_seed=0, counter=counter)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=2)
def main(_):
  key = jax.random.PRNGKey(FLAGS.seed)
  key_demonstrations, key_learner = jax.random.split(key, 2)

  # Create an environment and grab the spec.
  environment = gym_helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)

  # Get a demonstrations dataset with next_actions extra.
  transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                      FLAGS.num_demonstrations)
  double_transitions = rlds.transformations.batch(
      transitions, size=2, shift=1, drop_remainder=True)
  transitions = double_transitions.map(_add_next_action_extras)
  demonstrations = tfds.JaxInMemoryRandomSampleIterator(
      transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

  # Create the networks to optimize.
  networks = td3.make_networks(environment_spec)

  # Create the learner.
  learner = td3.TD3Learner(
      networks=networks,
      random_key=key_learner,
      discount=FLAGS.discount,
      iterator=demonstrations,
      policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
      critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      use_sarsa_target=FLAGS.use_sarsa_target,
      bc_alpha=FLAGS.bc_alpha,
      num_sgd_steps_per_step=1)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    del key
    return networks.policy_network.apply(params, observation)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
def main(_):
  # Create an environment, grab the spec, and use it to create networks.
  environment = helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)
  agent_networks = sac.make_networks(environment_spec)

  # Construct the agent.
  config = sac.SACConfig(
      target_entropy=sac.target_entropy_from_env_spec(environment_spec),
      num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
      seed=FLAGS.seed)
  agent = sac.SAC(environment_spec, agent_networks, config=config)

  # Create the environment loop used for training.
  train_loop = acme.EnvironmentLoop(environment, agent, label='train_loop')

  # Create the evaluation actor and loop.
  eval_actor = agent.builder.make_actor(
      policy_network=sac.apply_policy_and_sample(agent_networks,
                                                 eval_mode=True),
      variable_source=agent)
  eval_env = helpers.make_environment(task=FLAGS.env_name)
  eval_loop = acme.EnvironmentLoop(eval_env, eval_actor, label='eval_loop')

  assert FLAGS.num_steps % FLAGS.eval_every == 0
  for _ in range(FLAGS.num_steps // FLAGS.eval_every):
    eval_loop.run(num_episodes=5)
    train_loop.run(num_steps=FLAGS.eval_every)
  eval_loop.run(num_episodes=5)
def test_dqn(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  def network(x):
    model = hk.Sequential(
        [hk.Flatten(), hk.nets.MLP([50, 50, spec.actions.num_values])])
    return model(x)

  # Make the network purely functional.
  network_hk = hk.without_apply_rng(hk.transform(network, apply_rng=True))
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
  network = networks_lib.FeedForwardNetwork(
      init=lambda rng: network_hk.init(rng, dummy_obs),
      apply=network_hk.apply)

  # Construct the agent.
  agent = dqn.DQN(
      environment_spec=spec,
      network=network,
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)
def test_mompo(self, distributional_critic):
  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(episode_length=10)
  spec = specs.make_environment_spec(environment)

  # Create objectives.
  reward_objectives, qvalue_objectives = make_objectives()
  num_critic_heads = len(reward_objectives)

  # Create networks.
  agent_networks = make_networks(
      spec.actions,
      num_critic_heads=num_critic_heads,
      distributional_critic=distributional_critic)

  # Construct the agent.
  agent = mompo.MultiObjectiveMPO(
      reward_objectives,
      qvalue_objectives,
      spec,
      policy_network=agent_networks['policy'],
      critic_network=agent_networks['critic'],
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=2)
def test_dqfd(self):
  # Create a fake environment to test with.
  # TODO(b/152596848): Allow DQN to deal with integer observations.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  # Build demonstrations.
  dummy_action = np.zeros((), dtype=np.int32)
  recorder = bsuite_demonstrations.DemonstrationRecorder()
  timestep = environment.reset()
  while timestep.step_type is not dm_env.StepType.LAST:
    recorder.step(timestep, dummy_action)
    timestep = environment.step(dummy_action)
  recorder.step(timestep, dummy_action)
  recorder.record_episode()

  # Construct the agent.
  agent = dqfd.DQfD(
      environment_spec=spec,
      network=_make_network(spec.actions),
      demonstration_dataset=recorder.make_tf_dataset(),
      demonstration_ratio=0.5,
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=10)
def evaluator(
    random_key: networks_lib.PRNGKey,
    variable_source: core.VariableSource,
    counter: counting.Counter,
    make_actor: MakeActorFn,
):
  """The evaluation process."""
  # Create environment and evaluator networks.
  environment_key, actor_key = jax.random.split(random_key)
  # Environments normally require uint32 as a seed.
  environment = environment_factory(utils.sample_uint32(environment_key))
  networks = network_factory(specs.make_environment_spec(environment))
  actor = make_actor(actor_key, policy_factory(networks), variable_source)

  # Create logger and counter.
  counter = counting.Counter(counter, 'evaluator')
  if logger_fn is not None:
    logger = logger_fn('evaluator', 'actor_steps')
  else:
    logger = loggers.make_default_logger(
        'evaluator', log_to_bigtable, steps_key='actor_steps')

  # Create the run loop and return it.
  return environment_loop.EnvironmentLoop(environment, actor, counter, logger,
                                          observers=observers)
def run_ppo_agent(self, make_networks_fn):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  distribution_value_networks = make_networks_fn(spec)
  ppo_networks = ppo.make_ppo_networks(distribution_value_networks)
  config = ppo.PPOConfig(unroll_length=4, num_epochs=2, num_minibatches=2)
  workdir = self.create_tempdir()
  counter = counting.Counter()
  logger = loggers.make_default_logger('learner')

  # Construct the agent.
  agent = ppo.PPO(
      spec=spec,
      networks=ppo_networks,
      config=config,
      seed=0,
      workdir=workdir.full_path,
      normalize_input=True,
      counter=counter,
      logger=logger,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=20)
def test_train(self, policy_loss_coeff_fn):
  seed = 0
  num_iterations = 5
  batch_size = 64
  grad_updates_per_batch = 1

  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(
      episode_length=10, bounded=True, action_dim=6)
  spec = specs.make_environment_spec(environment)

  # Construct the learner.
  networks = crr.make_networks(
      spec, policy_layer_sizes=(8, 8), critic_layer_sizes=(8, 8))
  key = jax.random.PRNGKey(seed)
  dataset = fakes.transition_iterator(environment)
  learner = crr.CRRLearner(
      networks,
      key,
      discount=0.95,
      target_update_period=2,
      policy_loss_coeff_fn=policy_loss_coeff_fn,
      iterator=dataset(batch_size * grad_updates_per_batch),
      policy_optimizer=optax.adam(1e-4),
      critic_optimizer=optax.adam(1e-4),
      grad_updates_per_batch=grad_updates_per_batch)

  # Train the learner.
  for _ in range(num_iterations):
    learner.step()
def main(_):
  # Create an environment, grab the spec, and use it to create networks.
  environment = make_environment()
  environment_spec = specs.make_environment_spec(environment)
  agent_networks = make_networks(environment_spec.actions)

  # Construct the agent.
  agent = d4pg.D4PG(
      environment_spec=environment_spec,
      policy_network=agent_networks['policy'],
      critic_network=agent_networks['critic'],
      observation_network=agent_networks['observation'],  # pytype: disable=wrong-arg-types
  )

  # Create the environment loop used for training.
  train_loop = acme.EnvironmentLoop(environment, agent, label='train_loop')

  # Create the evaluation policy.
  eval_policy = snt.Sequential([
      agent_networks['observation'],
      agent_networks['policy'],
  ])

  # Create the evaluation actor and loop.
  eval_actor = actors.FeedForwardActor(policy_network=eval_policy)
  eval_env = make_environment()
  eval_loop = acme.EnvironmentLoop(eval_env, eval_actor, label='eval_loop')

  for _ in range(FLAGS.num_episodes // FLAGS.num_episodes_per_eval):
    train_loop.run(num_episodes=FLAGS.num_episodes_per_eval)
    eval_loop.run(num_episodes=1)
def test_feedforward(self):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)

  def policy(inputs: jnp.ndarray):
    return hk.Sequential([
        hk.Flatten(),
        hk.Linear(env_spec.actions.num_values),
        lambda x: jnp.argmax(x, axis=-1),
    ])(inputs)

  policy = hk.transform(policy, apply_rng=True)

  rng = hk.PRNGSequence(1)
  dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  params = policy.init(next(rng), dummy_obs)

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  actor = actors.FeedForwardActor(
      policy.apply, rng=hk.PRNGSequence(1), variable_client=variable_client)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def main(_):
  # Create an environment and environment model.
  environment, model = make_env_and_model(
      bsuite_id=FLAGS.bsuite_id,
      results_dir=FLAGS.results_dir,
      overwrite=FLAGS.overwrite,
  )
  environment_spec = specs.make_environment_spec(environment)

  # Create the network and optimizer.
  network = make_network(environment_spec.actions)
  optimizer = snt.optimizers.Adam(learning_rate=1e-3)

  # Construct the agent.
  agent = mcts.MCTS(
      environment_spec=environment_spec,
      model=model,
      network=network,
      optimizer=optimizer,
      discount=0.99,
      replay_capacity=10000,
      n_step=1,
      batch_size=16,
      num_simulations=50,
  )

  # Run the environment loop.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=environment.bsuite_num_episodes)  # pytype: disable=attribute-error
def test_train(self):
  seed = 0
  num_iterations = 2
  batch_size = 64

  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(
      episode_length=10, bounded=True, action_dim=6)
  spec = specs.make_environment_spec(environment)

  # Construct the agent.
  networks = cql.make_networks(spec, hidden_layer_sizes=(8, 8))
  dataset = fakes.transition_iterator(environment)
  key = jax.random.PRNGKey(seed)
  learner = cql.CQLLearner(
      batch_size,
      networks,
      key,
      demonstrations=dataset(batch_size),
      policy_optimizer=optax.adam(3e-5),
      critic_optimizer=optax.adam(3e-4),
      cql_lagrange_threshold=1.,
      target_entropy=0.1,
      num_sgd_steps_per_step=1)

  # Train the agent.
  for _ in range(num_iterations):
    learner.step()
def replay(self):
  """The replay storage."""
  dummy_seed = 1
  environment_spec = (
      self._environment_spec or
      specs.make_environment_spec(self._environment_factory(dummy_seed)))
  return self._builder.make_replay_tables(environment_spec)
def test_make_dataset_with_sequence_length_and_batch_size(self):
  sequence_length = 6
  batch_size = 4
  environment = fakes.ContinuousEnvironment()
  environment_spec = specs.make_environment_spec(environment)
  dataset = reverb_dataset.make_dataset(
      client=self.tf_client,
      environment_spec=environment_spec,
      batch_size=batch_size,
      sequence_length=sequence_length)

  def make_tensor_spec(spec):
    return tf.TensorSpec(
        shape=(batch_size, sequence_length) + spec.shape, dtype=spec.dtype)

  expected_spec = tree.map_structure(make_tensor_spec, environment_spec)

  expected_spec = adders.Step(
      observation=expected_spec.observations,
      action=expected_spec.actions,
      reward=expected_spec.rewards,
      discount=expected_spec.discounts,
      start_of_episode=specs.Array(
          shape=(batch_size, sequence_length), dtype=bool),
      extras=())

  self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))
def save_ilp(environment, filename):
  environment_spec = specs.make_environment_spec(environment)

  # Create actor.
  meta_agent = msgi.MSGIRandom(
      environment_spec=environment_spec,
      num_adapt_steps=NUM_ADAPT_STEPS,
      visualize=False,
      directory=None,
      environment_id=ENV_NAME)
  agent = meta_agent.instantiate_adapt_agent()  # MSGIActor() (fast agent).

  # Set task.
  environment.reset_task(task_index=TASK_IDX)
  meta_agent.reset_agent(environment=environment)

  # Run RL loop.
  # Reset.
  timestep = environment.reset()
  agent.observe_first(timestep)

  # Loop.
  cumulative_reward = 0.
  for step_count in range(NUM_ADAPT_STEPS):
    action = agent.select_action(timestep.observation)
    timestep = environment.step(action)
    agent.observe(action, next_timestep=timestep)

    # Book-keeping.
    cumulative_reward += sum(timestep.reward)  # TODO: needs check.

  meta_agent._ilp.save(filename)
def test_msgi(env_id, graph_param, seed, num_envs):
  adapt_envs, test_envs = make_envs(env_id, graph_param, seed, num_envs)
  environment_spec = specs.make_environment_spec(adapt_envs)

  # Create meta agent.
  meta_agent = agents.MSGI(
      environment_spec=environment_spec,
      num_adapt_steps=20,
      num_trial_splits=5,
      environment_id=env_id,
      branch_neccessary_first=True,
      exploration='random',
      temp=200,
      w_a=3.0,
      beta_a=8.0,
      ep_or=0.8,
      temp_or=2.0)

  # Run meta loop.
  meta_loop = environment_loop.EnvironmentMetaLoop(
      adapt_environment=adapt_envs,
      test_environment=test_envs,
      meta_agent=meta_agent,
      label='meta_eval')

  meta_loop.run(
      num_trials=3,
      num_adapt_steps=20,  # Should be greater than TimeLimit.
      num_test_episodes=1,
      num_trial_splits=5)