def test_reward_monitoring():
    """ Reward logging unit test """

    # instantiate env
    env = build_dummy_maze_env()
    env = MazeEnvMonitoringWrapper.wrap(env, observation_logging=False, action_logging=False, reward_logging=True)
    env = LogStatsWrapper.wrap(env)  # for accessing events from previous steps

    env.reset()
    env.step(env.action_space.sample())

    # test application of wrapper
    for ii in range(2):
        env.step(env.action_space.sample())
        reward_events = env.get_last_step_events(query=[RewardEvents.reward_original,
                                                        RewardEvents.reward_processed])

        assert len(reward_events) == 2
        for event in reward_events:
            assert issubclass(event.interface_class, RewardEvents)
            assert event.attributes['value'] == 10
            assert event.interface_method in [RewardEvents.reward_original, RewardEvents.reward_processed]
def test_observation_monitoring():
    """ Observation logging unit test """

    # instantiate env
    env = build_dummy_maze_env()
    env = MazeEnvMonitoringWrapper.wrap(env, observation_logging=True, action_logging=False, reward_logging=False)
    env = LogStatsWrapper.wrap(env)  # for accessing events from previous steps

    env.reset()

    # test application of wrapper
    for ii in range(3):
        # Observation will get reported in the next step (when the agent is actually acting on it)
        obs = env.step(env.action_space.sample())[0]

        observation_events = env.get_last_step_events(query=[ObservationEvents.observation_original,
                                                             ObservationEvents.observation_processed])

        assert len(observation_events) == 4
        for event in observation_events:
            assert issubclass(event.interface_class, ObservationEvents)
            obs_name = event.attributes['name']
            assert obs_name in ['observation_0', 'observation_1']
            if ii > 0:
                assert np.allclose(np.asarray(obs[obs_name]), np.asarray(event.attributes['value']))
def test_action_monitoring():
    """ Action logging unit test """

    # instantiate env
    env = build_dummy_maze_env()
    env = MazeEnvMonitoringWrapper.wrap(env, observation_logging=False, action_logging=True, reward_logging=False)
    env = LogStatsWrapper.wrap(env)  # for accessing events from previous steps

    env.reset()

    # test application of wrapper
    for ii in range(2):
        env.step(env.action_space.sample())
        action_events = env.get_last_step_events(query=[ActionEvents.discrete_action,
                                                        ActionEvents.continuous_action,
                                                        ActionEvents.multi_binary_action])

        assert len(action_events) == 7
        for event in action_events:
            if event.attributes['name'] in ['action_0_0', 'action_0_1_0', 'action_0_1_1', 'action_1_0']:
                assert event.interface_method == ActionEvents.discrete_action
            elif event.attributes['name'] in ['action_0_2', 'action_2_0']:
                assert event.interface_method == ActionEvents.continuous_action
            elif event.attributes['name'] in ['action_1_1']:
                assert event.interface_method == ActionEvents.multi_binary_action
            else:
                raise ValueError
def test_supports_multi_step_wrappers():
    env = build_dummy_structured_env()
    env = LogStatsWrapper.wrap(env)
    agent_deployment = AgentDeployment(
        policy=DummyGreedyPolicy(),
        env=env
    )

    # Step the environment manually here and query the agent integration wrapper for maze_actions
    test_core_env = build_dummy_structured_env().core_env
    maze_state = test_core_env.reset()
    reward, done, info = 0, False, {}
    for i in range(4):
        maze_action = agent_deployment.act(maze_state, reward, done, info)
        maze_state, reward, done, info = test_core_env.step(maze_action)
    agent_deployment.close(maze_state, reward, done, info)

    assert env.get_stats_value(BaseEnvEvents.reward, LogStatsLevel.EPOCH, name="total_step_count") == 4

    # Step count is still 4 even with multiple sub-steps, as it is detected based on env time
    assert env.get_stats_value(RewardEvents.reward_original, LogStatsLevel.EPOCH, name="total_step_count") == 4
def test_records_stats():
    env = LogStatsWrapper.wrap(build_dummy_maze_env())
    agent_deployment = AgentDeployment(
        policy=DummyGreedyPolicy(),
        env=env
    )

    # Step the environment manually here and query the agent integration wrapper for maze_actions
    test_core_env = build_dummy_maze_env().core_env
    maze_state = test_core_env.reset()
    reward, done, info = 0, False, {}
    for i in range(5):
        maze_action = agent_deployment.act(maze_state, reward, done, info)
        maze_state, reward, done, info = test_core_env.step(maze_action)
    agent_deployment.close(maze_state, reward, done, info)

    assert env.get_stats_value(RewardEvents.reward_original, LogStatsLevel.EPOCH, name="total_step_count") == 5
    assert env.get_stats_value(BaseEnvEvents.reward, LogStatsLevel.EPOCH, name="total_step_count") == 5
def __init__(self,
             env: Union[MazeEnv, StructuredVectorEnv],
             record_logits: bool = False,
             record_step_stats: bool = False,
             record_episode_stats: bool = False,
             record_next_observations: bool = False,
             terminate_on_done: bool = False):
    self.env = env
    self.is_vectorized = isinstance(self.env, StructuredVectorEnv)
    self.record_logits = record_logits
    self.record_step_stats = record_step_stats
    self.record_episode_stats = record_episode_stats
    self.record_next_observations = record_next_observations
    self.terminate_on_done = terminate_on_done

    if (self.record_step_stats or self.record_episode_stats) and not isinstance(self.env, LogStatsWrapper):
        self.env = LogStatsWrapper.wrap(self.env)

    self.step_keys = list(env.observation_spaces_dict.keys())  # Only synchronous environments are supported
    self.last_observation = None  # Keep last observations and do not reset envs between rollouts
    self.rollout_counter = 0  # For generating trajectory IDs if none are supplied
def train(env):
    """ unit test helper function """
    n_episodes = 10
    n_steps_per_episode = 5

    # setup logging
    writer = LogStatsWriterTensorboard(log_dir='test_log', tensorboard_render_figure=True)
    register_log_stats_writer(writer)
    # attach a console writer as well for immediate console feedback
    register_log_stats_writer(LogStatsWriterConsole())

    env = LogStatsWrapper.wrap(env, logging_prefix="train")
    with SimpleStatsLoggingSetup(env):
        for episode in range(n_episodes):
            _ = env.reset()
            for step in range(n_steps_per_episode):
                # take random action
                action = env.action_space.sample()

                # take step in env and trigger log stats writing
                _, _, _, _ = env.step(action)

    # test accessing stats
    env.get_stats(LogStatsLevel.EPOCH)
    env.get_stats_value(BaseEnvEvents.reward, LogStatsLevel.EPOCH, name="mean")
def run_with(self, env: ConfigType, wrappers: CollectionOfConfigType, agent: ConfigType):
    """Run the rollout sequentially in the main process."""
    env, agent = self.init_env_and_agent(env, wrappers, self.max_episode_steps, agent, self.input_dir,
                                         self.maze_seeding.generate_env_instance_seed(),
                                         self.maze_seeding.generate_agent_instance_seed())

    # Set up the wrappers
    # Hydra handles the working directory
    register_log_stats_writer(LogStatsWriterConsole())
    if not isinstance(env, LogStatsWrapper):
        env = LogStatsWrapper.wrap(env, logging_prefix="rollout_data")

    if self.record_event_logs:
        LogEventsWriterRegistry.register_writer(LogEventsWriterTSV(log_dir="./event_logs"))

    if self.record_trajectory:
        TrajectoryWriterRegistry.register_writer(TrajectoryWriterFile(log_dir="./trajectory_data"))
        if not isinstance(env, TrajectoryRecordingWrapper):
            env = TrajectoryRecordingWrapper.wrap(env)

    self.progress_bar = tqdm(desc="Episodes done", unit=" episodes", total=self.n_episodes)
    RolloutRunner.run_interaction_loop(env, agent, self.n_episodes,
                                       render=self.render,
                                       episode_end_callback=lambda: self.update_progress())
    self.progress_bar.close()

    env.write_epoch_stats()
def env_factory() -> EnvType:
    """Create a new env in order to apply the monkey patch in the worker process.

    :return: A new env instance wrapped for stats logging.
    """
    env = EnvFactory(cfg.env, cfg.wrappers if 'wrappers' in cfg else {})()
    logging_env = LogStatsWrapper.wrap(env)
    return logging_env
def __init__(self, env: StructuredEnv, n_eval_rollouts: int, shared_noise: SharedNoiseTable,
             agent_instance_seed: int):
    env = TimeLimitWrapper.wrap(env)
    env = LogStatsWrapper.wrap(env)
    self.env = ESRolloutWorkerWrapper.wrap(env=env, shared_noise=shared_noise,
                                           agent_instance_seed=agent_instance_seed)

    self.n_eval_rollouts = n_eval_rollouts
def __init__(self, env: T, log_dir: str = '.'):
    """
    :param env: The environment to monitor
    :param log_dir: Where to log the monitoring data
    """
    self.env = env
    self.log_dir = log_dir

    # Wrap the env to enable stats, events, and trajectory data logging
    self.env = LogStatsWrapper.wrap(self.env, logging_prefix="eval")
    self.env = TrajectoryRecordingWrapper.wrap(self.env)
def test_observation_monitoring():
    """ Observation logging unit test """
    env = GymMazeEnv(env="CartPole-v0")
    env = ObservationVisualizationWrapper.wrap(env, plot_function=None)
    env = LogStatsWrapper.wrap(env, logging_prefix="train")

    with SimpleStatsLoggingSetup(env, log_dir="."):
        env.reset()
        done = False
        while not done:
            obs, rew, done, info = env.step(env.action_space.sample())
def __init__(self, env_factories: List[Callable[[], MazeEnv]], logging_prefix: Optional[str] = None):
    self.envs = [LogStatsWrapper.wrap(env_fn()) for env_fn in env_factories]

    super().__init__(
        n_envs=len(env_factories),
        action_spaces_dict=self.envs[0].action_spaces_dict,
        observation_spaces_dict=self.envs[0].observation_spaces_dict,
        agent_counts_dict=self.envs[0].agent_counts_dict,
        logging_prefix=logging_prefix)
def test_records_events_in_reset():
    env = build_dummy_maze_env()
    env = _EventsInResetWrapper.wrap(env)
    env = LogStatsWrapper.wrap(env)

    env.reset()
    for i in range(5):
        env.step(env.action_space.sample())

    env.write_epoch_stats()
    # only from the single event fired during env reset
    assert env.get_stats_value(BaseEnvEvents.test_event, LogStatsLevel.EPOCH) == 1
def _run_rollout_loop(env: Union[BaseEnv, gym.Env], n_steps_per_episode: int, n_episodes: int,
                      writer: LogEventsWriter):
    LogEventsWriterRegistry.writers = []  # Ensure there is no other writer
    LogEventsWriterRegistry.register_writer(writer)

    env = LogStatsWrapper.wrap(env)
    for _ in range(n_episodes):
        env.reset()
        for i in range(n_steps_per_episode):
            env.step(env.action_space.sample())
            if i == n_steps_per_episode - 1:
                env.render_stats()
    env.close()

    env.write_epoch_stats()  # Test ending without env reset
def __init__(self,
             policy: ConfigType,
             env: ConfigType,
             wrappers: CollectionOfConfigType = None,
             num_candidates: int = 1):
    self.rollout_done = False

    # Thread synchronisation
    self.state_queue = Queue(maxsize=1)
    self.maze_action_queue = Queue(maxsize=1)
    self.rollout_done_event = Event()

    # Build simulation env from config (like we would do during training), then swap core env for external one
    self.env = EnvFactory(env, wrappers if wrappers else {})()
    if not isinstance(self.env, LogStatsWrapper):
        self.env = LogStatsWrapper.wrap(self.env)
    self.external_core_env = ExternalCoreEnv(
        context=self.env.core_env.context,
        state_queue=self.state_queue,
        maze_action_queue=self.maze_action_queue,
        rollout_done_event=self.rollout_done_event,
        renderer=self.env.core_env.get_renderer())

    # Due to the fake subclass hierarchy generated in each Wrapper, we need to make sure
    # we swap the core env directly on the MazeEnv, not on any wrapper above it
    maze_env = self.env
    while isinstance(maze_env.env, MazeEnv):
        maze_env = maze_env.env
    maze_env.core_env = self.external_core_env

    # If we are working with multiple candidate actions, wrap the action_conversion interfaces
    if num_candidates > 1:
        for policy_id, action_conversion in self.env.action_conversion_dict.items():
            self.env.action_conversion_dict[policy_id] = ActionConversionCandidatesInterface(action_conversion)

    # Policy executor, running the rollout loop on a separate thread
    self.policy = Factory(base_type=Policy).instantiate(policy)
    self.policy_executor = PolicyExecutor(
        env=self.env,
        policy=self.policy,
        rollout_done_event=self.rollout_done_event,
        exception_queue=self.maze_action_queue,
        num_candidates=num_candidates)
    self.policy_thread = Thread(target=self.policy_executor.run_rollout_loop, daemon=True)
    self.policy_thread.start()
def test_records_policy_events():
    env = build_dummy_maze_env()
    env = LogStatsWrapper.wrap(env)
    base_events = env.core_env.context.event_service.create_event_topic(BaseEnvEvents)

    env.reset()
    for i in range(5):
        base_events.test_event(1)  # Simulate firing event from policy (= outside of env.step)
        env.step(env.action_space.sample())

    env.write_epoch_stats()
    assert env.get_stats_value(BaseEnvEvents.test_event, LogStatsLevel.EPOCH) == 5  # value of 1 x 5 steps
def test_records_stats():
    # the default simple setup: flat, single-step env, no step skipping etc.
    env = build_dummy_maze_env()
    env = LogStatsWrapper.wrap(env)

    env.reset()
    for i in range(5):
        env.step(env.action_space.sample())

    # both step counts, seen from outside and seen from the core env, should correspond to 5
    env.write_epoch_stats()
    assert env.get_stats_value(RewardEvents.reward_original, LogStatsLevel.EPOCH, name="total_step_count") == 5
    assert env.get_stats_value(BaseEnvEvents.reward, LogStatsLevel.EPOCH, name="total_step_count") == 5
def test_step_increment_in_single_step_core_env():
    """In single sub-step envs, events should be cleared out and env time incremented automatically."""
    env = build_dummy_maze_env()
    env = LogStatsWrapper.wrap(env)
    env.reset()
    assert env.get_env_time() == 0

    # 10 steps
    for _ in range(10):
        env.step(env.action_space.sample())
    assert env.get_env_time() == 10

    env.reset()
    increment_log_step()
    assert env.get_stats_value(BaseEnvEvents.reward, LogStatsLevel.EPOCH, name="total_step_count") == 10
def _setup_monitoring(env: StructuredEnv, record_trajectory: bool) -> Tuple[StructuredEnv, EpisodeRecorder]:
    """Set up monitoring wrappers.

    Stats and event logs are collected in the episode recorder, so that they can be shipped to the main process
    at the end of each episode.
    """
    if not isinstance(env, LogStatsWrapper):
        env = LogStatsWrapper.wrap(env)
    episode_recorder = EpisodeRecorder()
    LogEventsWriterRegistry.register_writer(episode_recorder)
    env.episode_stats.register_consumer(episode_recorder)

    # Trajectory recording happens in the worker process
    # Hydra handles the working directory
    if record_trajectory:
        TrajectoryWriterRegistry.register_writer(TrajectoryWriterFile(log_dir="./trajectory_data"))
        if not isinstance(env, TrajectoryRecordingWrapper):
            env = TrajectoryRecordingWrapper.wrap(env)

    return env, episode_recorder
def _build_env():
    env = DummyEnvironment(
        core_env=DummyCoreEnvironment(gym.spaces.Discrete(10)),
        action_conversion=[{"_target_": DoubleActionConversion}],
        observation_conversion=[{"_target_": DoubleObservationConversion}])

    env = _DummyActionWrapper.wrap(env)
    env = _DummyObservationWrapper.wrap(env)
    env = _DummyRewardWrapper.wrap(env)
    env = TimeLimitWrapper.wrap(env)
    env = LogStatsWrapper.wrap(env)
    env = TrajectoryRecordingWrapper.wrap(env)

    return env
def test_step_increment_in_structured_core_environments():
    """Structured core envs manage the step incrementing themselves and Maze env should not interfere with that."""
    env = build_dummy_maze_env_with_structured_core_env()
    env = LogStatsWrapper.wrap(env)
    env.reset()
    assert env.get_env_time() == 0

    # Do 10 agent steps => 5 structured steps (as we have two agents)
    for _ in range(10):
        env.step(env.action_space.sample())
    assert env.get_env_time() == 5

    env.reset()
    increment_log_step()
    assert env.get_stats_value(BaseEnvEvents.reward, LogStatsLevel.EPOCH, name="total_step_count") == 5
def test_observation_skipping_wrapper_sticky_flat():
    """ Step skipping unit test """
    n_steps = 3

    # instantiate env
    env = GymMazeEnv("CartPole-v0")
    env = StepSkipWrapper.wrap(env, n_steps=n_steps, skip_mode='sticky')
    env = LogStatsWrapper.wrap(env)  # for accessing events from previous steps

    # reset environment and run interaction loop
    env.reset()
    cum_rew = 0
    for i in range(2):
        action = env.action_space.sample()
        obs, reward, done, info = env.step(action)
        cum_rew += reward

        assert len(env.get_last_step_events(query=RewardEvents.reward_original)) == 1

    assert cum_rew == 6
def test_handles_step_skipping_in_reset():
    env = build_dummy_maze_env()
    env = _StepInResetWrapper.wrap(env)
    env = LogStatsWrapper.wrap(env)
    env.reset()

    # Step the env once (should be the third step -- first two were done in the reset)
    env.step(env.action_space.sample())

    # Events should be collected for 3 steps in total -- two from the env reset done by the wrapper + one done above
    assert len(env.episode_event_log.step_event_logs) == 3

    # The same goes for "original reward" stats
    env.write_epoch_stats()
    assert env.get_stats_value(RewardEvents.reward_original, LogStatsLevel.EPOCH, name="total_step_count") == 3

    # The step count from outside is still one (as normal reward events should not be fired for "skipped" steps)
    assert env.get_stats_value(BaseEnvEvents.reward, LogStatsLevel.EPOCH, name="total_step_count") == 1
def test_handles_multi_step_setup():
    env = build_dummy_structured_env()
    env = LogStatsWrapper.wrap(env)

    # Step the env four times (should correspond to two core-env steps)
    env.reset()
    for i in range(4):
        env.step(env.action_space.sample())

    # => events should be collected for 2 steps in total
    assert len(env.episode_event_log.step_event_logs) == 2

    # The same goes for both reward stats, from outside and from the core-env perspective
    env.write_epoch_stats()
    assert env.get_stats_value(RewardEvents.reward_original, LogStatsLevel.EPOCH, name="total_step_count") == 2
    assert env.get_stats_value(BaseEnvEvents.reward, LogStatsLevel.EPOCH, name="total_step_count") == 2
def test_handles_step_skipping_in_step():
    env = build_dummy_maze_env()
    env = _StepInStepWrapper.wrap(env)
    env = LogStatsWrapper.wrap(env)

    # Step the env twice (should correspond to four core-env steps)
    env.reset()
    for i in range(2):
        env.step(env.action_space.sample())

    # => events should be collected for 4 steps in total
    assert len(env.episode_event_log.step_event_logs) == 4

    # The same goes for "original reward" stats
    env.write_epoch_stats()
    assert env.get_stats_value(RewardEvents.reward_original, LogStatsLevel.EPOCH, name="total_step_count") == 4

    # The step count from outside is still just two (as normal reward events should not be fired for "skipped" steps)
    assert env.get_stats_value(BaseEnvEvents.reward, LogStatsLevel.EPOCH, name="total_step_count") == 2
def __init__(self,
             pickled_env_factory: bytes,
             pickled_policy: bytes,
             shared_noise: SharedNoiseTable,
             output_queue: multiprocessing.Queue,
             broadcasting_container: BroadcastingContainer,
             env_seed: int,
             agent_seed: int,
             is_eval_worker: bool):
    self.policy = cloudpickle.loads(pickled_policy)
    self.policy_version_counter = -1
    self.aux_data = None

    self.output_queue = output_queue
    self.broadcasting_container = broadcasting_container
    self.is_eval_worker = is_eval_worker

    env_factory = cloudpickle.loads(pickled_env_factory)
    self.env = env_factory()
    self.env = TimeLimitWrapper.wrap(self.env)
    if not isinstance(self.env, LogStatsWrapper):
        self.env = LogStatsWrapper.wrap(self.env)
    self.env.seed(env_seed)
    self.env = ESRolloutWorkerWrapper.wrap(env=self.env, shared_noise=shared_noise,
                                           agent_instance_seed=agent_seed)
def test_concepts_and_structures_run_context_overview():
    """
    Tests snippets in docs/source/concepts_and_structure/run_context_overview.rst.
    """
    # Default overrides for faster tests. Shouldn't change functionality.
    ac_overrides = {"runner.concurrency": 1}
    es_overrides = {"algorithm.n_epochs": 1, "algorithm.n_rollouts_per_update": 1}

    # Training
    # --------

    rc = RunContext(
        algorithm="a2c",
        overrides={"env.name": "CartPole-v0", **ac_overrides},
        model="vector_obs",
        critic="template_state",
        runner="dev",
        configuration="test"
    )
    rc.train(n_epochs=1)

    alg_config = A2CAlgorithmConfig(
        n_epochs=1,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([lambda: GymMazeEnv("CartPole-v0")]),
            n_episodes=1,
            model_selection=None,
            deterministic=True
        )
    )

    rc = RunContext(
        algorithm=alg_config,
        overrides={"env.name": "CartPole-v0", **ac_overrides},
        model="vector_obs",
        critic="template_state",
        runner="dev",
        configuration="test"
    )
    rc.train(n_epochs=1)

    rc = RunContext(env=lambda: GymMazeEnv('CartPole-v0'), overrides=es_overrides, runner="dev", configuration="test")
    rc.train(n_epochs=1)

    policy_composer_config = {
        '_target_': 'maze.perception.models.policies.ProbabilisticPolicyComposer',
        'networks': [{
            '_target_': 'maze.perception.models.built_in.flatten_concat.FlattenConcatPolicyNet',
            'non_lin': 'torch.nn.Tanh',
            'hidden_units': [256, 256]
        }],
        "substeps_with_separate_agent_nets": [],
        "agent_counts_dict": {0: 1}
    }
    rc = RunContext(
        overrides={"model.policy": policy_composer_config, **es_overrides},
        runner="dev",
        configuration="test"
    )
    rc.train(n_epochs=1)

    env = GymMazeEnv('CartPole-v0')
    policy_composer = ProbabilisticPolicyComposer(
        action_spaces_dict=env.action_spaces_dict,
        observation_spaces_dict=env.observation_spaces_dict,
        distribution_mapper=DistributionMapper(action_space=env.action_space, distribution_mapper_config={}),
        networks=[{
            '_target_': 'maze.perception.models.built_in.flatten_concat.FlattenConcatPolicyNet',
            'non_lin': 'torch.nn.Tanh',
            'hidden_units': [222, 222]
        }],
        substeps_with_separate_agent_nets=[],
        agent_counts_dict={0: 1}
    )
    rc = RunContext(overrides={"model.policy": policy_composer, **es_overrides}, runner="dev", configuration="test")
    rc.train(n_epochs=1)

    rc = RunContext(algorithm=alg_config, overrides=ac_overrides, runner="dev", configuration="test")
    rc.train(n_epochs=1)
    rc.train()

    # Rollout
    # -------

    obs = env.reset()
    for i in range(10):
        action = rc.compute_action(obs)
        obs, rewards, dones, info = env.step(action)

    # Evaluation
    # ----------

    env.reset()
    evaluator = RolloutEvaluator(
        # Environment has to have statistics logging capabilities for RolloutEvaluator.
        eval_env=LogStatsWrapper.wrap(env, logging_prefix="eval"),
        n_episodes=1,
        model_selection=None
    )
    evaluator.evaluate(rc.policy)
def _worker(remote, parent_remote, env_fn_wrapper):
    # switch to non-interactive matplotlib backend
    matplotlib.use('Agg')

    parent_remote.close()
    env: MazeEnv = env_fn_wrapper.var()

    # enable collection of logging statistics
    if not isinstance(env, LogStatsWrapper):
        env = LogStatsWrapper.wrap(env)

    # discard epoch-level statistics (as stats are shipped to the main process after each episode)
    env = disable_epoch_level_stats(env)

    while True:
        try:
            cmd, data = remote.recv()
            if cmd == 'step':
                observation, reward, env_done, info = env.step(data)
                actor_done = env.is_actor_done()
                actor_id = env.actor_id()

                episode_stats = None
                if env_done:
                    # save final observation where user can get it, then reset
                    info['terminal_observation'] = observation
                    observation = env.reset()

                    # collect episode stats after the reset
                    episode_stats = env.get_stats(LogStatsLevel.EPISODE).last_stats

                remote.send((observation, reward, env_done, info, actor_done, actor_id,
                             episode_stats, env.get_env_time()))
            elif cmd == 'seed':
                env.seed(data)
            elif cmd == 'reset':
                observation = env.reset()
                actor_done = env.is_actor_done()
                actor_id = env.actor_id()
                remote.send((observation, actor_done, actor_id,
                             env.get_stats(LogStatsLevel.EPISODE).last_stats, env.get_env_time()))
            elif cmd == 'close':
                remote.close()
                break
            elif cmd == 'get_spaces':
                remote.send((env.observation_spaces_dict, env.action_spaces_dict, env.agent_counts_dict))
            elif cmd == 'get_actor_rewards':
                remote.send(env.get_actor_rewards())
            elif cmd == 'env_method':
                method = getattr(env, data[0])
                remote.send(method(*data[1], **data[2]))
            elif cmd == 'get_attr':
                remote.send(getattr(env, data))
            elif cmd == 'set_attr':
                remote.send(setattr(env, data[0], data[1]))
            else:
                raise NotImplementedError
        except EOFError:
            break
def train(n_epochs: int) -> int:
    """
    Trains agent in pure Python.

    :param n_epochs: Number of epochs to train.
    :return: 0 if successful.
    """
    # Environment setup
    # -----------------

    env = cartpole_env_factory()

    # Algorithm setup
    # ---------------

    algorithm_config = A2CAlgorithmConfig(
        n_epochs=5,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([cartpole_env_factory]),
            n_episodes=1,
            model_selection=None,
            deterministic=True
        )
    )

    # Custom model setup
    # ------------------

    # Policy customization
    # ^^^^^^^^^^^^^^^^^^^^

    # Policy network.
    policy_net = CartpolePolicyNet(
        obs_shapes={'observation': env.observation_space.spaces['observation'].shape},
        action_logit_shapes={'action': (env.action_space.spaces['action'].n,)}
    )
    policy_networks = [policy_net]

    # Policy distribution.
    distribution_mapper = DistributionMapper(action_space=env.action_space, distribution_mapper_config={})

    # Policy composer.
    policy_composer = ProbabilisticPolicyComposer(
        action_spaces_dict=env.action_spaces_dict,
        observation_spaces_dict=env.observation_spaces_dict,
        # Derive distribution from environment's action space.
        distribution_mapper=distribution_mapper,
        networks=policy_networks,
        # We have only one agent and network, thus this is an empty list.
        substeps_with_separate_agent_nets=[],
        # We have only one step and one agent.
        agent_counts_dict={0: 1}
    )

    # Critic customization
    # ^^^^^^^^^^^^^^^^^^^^

    # Value networks.
    value_networks = {
        0: TorchModelBlock(
            in_keys='observation',
            out_keys='value',
            in_shapes=env.observation_space.spaces['observation'].shape,
            in_num_dims=[2],
            out_num_dims=2,
            net=CartpoleValueNet({'observation': env.observation_space.spaces['observation'].shape})
        )
    }

    # Critic composer.
    critic_composer = SharedStateCriticComposer(
        observation_spaces_dict=env.observation_spaces_dict,
        agent_counts_dict={0: 1},
        networks=value_networks,
        stack_observations=True
    )

    # Training
    # ^^^^^^^^

    rc = run_context.RunContext(
        env=cartpole_env_factory,
        algorithm=algorithm_config,
        policy=policy_composer,
        critic=critic_composer,
        runner="dev"
    )
    rc.train(n_epochs=n_epochs)

    # Distributed training
    # ^^^^^^^^^^^^^^^^^^^^

    algorithm_config.rollout_evaluator.eval_env = SubprocVectorEnv([cartpole_env_factory])
    rc = run_context.RunContext(
        env=cartpole_env_factory,
        algorithm=algorithm_config,
        policy=policy_composer,
        critic=critic_composer,
        runner="local"
    )
    rc.train(n_epochs=n_epochs)

    # Evaluation
    # ^^^^^^^^^^

    print("-----------------")
    evaluator = RolloutEvaluator(
        eval_env=LogStatsWrapper.wrap(cartpole_env_factory(), logging_prefix="eval"),
        n_episodes=1,
        model_selection=None
    )
    evaluator.evaluate(rc.policy)

    return 0