def create_agent(environment, obs_stacker, agent_type='DQN'):
    """Creates the Hanabi agent.

    Args:
        environment: The environment.
        obs_stacker: Observation stacker object.
        agent_type: str, type of agent to construct.

    Returns:
        An agent for playing Hanabi.

    Raises:
        ValueError: if an unknown agent type is requested.
    """
    if agent_type == 'DQN':
        return dqn_agent.DQNAgent(
            observation_size=obs_stacker.observation_size(),
            num_actions=environment.num_moves(),
            num_players=environment.players)
    elif agent_type == 'Rainbow':
        return rainbow_agent.RainbowAgent(
            observation_size=obs_stacker.observation_size(),
            num_actions=environment.num_moves(),
            num_players=environment.players)
    else:
        raise ValueError(
            'Expected valid agent_type, got {}'.format(agent_type))
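For context, a minimal usage sketch of this factory. `create_environment` and `create_obs_stacker` are assumed companion helpers in the Dopamine-style Hanabi training setup, not defined in this file, and the game settings are illustrative:

# Hypothetical usage sketch -- helper names and game settings are assumptions.
environment = create_environment(game_type='Hanabi-Full', num_players=2)
obs_stacker = create_obs_stacker(environment)
agent = create_agent(environment, obs_stacker, agent_type='Rainbow')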
import gym
import torch

import dqn_agent  # project-local DQN agent module

env = gym.make('CartPole-v1')

# NOTE: assumes the classic gym API (< 0.26), where reset() returns the
# observation alone rather than an (observation, info) tuple.
observation = env.reset()
observation = torch.from_numpy(observation)
rand_tensor = torch.rand(observation.shape)

print("observation: ", observation)
print("random tensor: ", rand_tensor)

# Quick smoke test: build an agent and run one forward pass through its model.
agent = dqn_agent.DQNAgent(observation.shape, env.action_space.n)
print(agent.model.forward(observation.float()))
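The smoke test above assumes a `DQNAgent` that exposes a `model` network mapping a flat state vector to one Q-value per action. A minimal sketch of that interface, assuming a small fully connected network; this is an illustrative stand-in, not the project's actual `dqn_agent` implementation:

import torch.nn as nn


class DQNAgent:
    """Illustrative stand-in matching the constructor and `.model` usage above."""

    def __init__(self, observation_shape, num_actions):
        # A small MLP over the flat CartPole state. The hidden layer size is
        # an arbitrary assumption; the output is one Q-value per action.
        self.model = nn.Sequential(
            nn.Linear(observation_shape[0], 128),
            nn.ReLU(),
            nn.Linear(128, num_actions),
        )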
import argparse
import random

import gym
import numpy as np
import torch

# Project-local modules (defined elsewhere in this repo): the Atari
# pre-processing wrapper, the DQN agent, LogTupStruct, and write_image.
import atari_lib
import dqn_agent


def run_environment(args: argparse.Namespace, device: str = 'cpu', logger=None):
    # =====================================================
    # Initialize environment and pre-processing
    screen_size = 84
    raw_env = gym.make(args.env_name)
    raw_env.seed(args.seed)

    # Reproducibility settings
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    environment = atari_lib.AtariPreprocessing(raw_env,
                                               frame_skip=args.frame_skips,
                                               terminal_on_life_loss=True,
                                               screen_size=screen_size)
    num_actions = environment.action_space().n
    observation_shape = (1, screen_size, screen_size)

    # =====================================================
    # Initialize agent
    # TODO: update this to use the vq-v2_dqn_agent
    agent = dqn_agent.DQNAgent(
        num_actions=num_actions,
        observation_shape=observation_shape,
        observation_dtype=torch.uint8,
        history_size=args.history_size,
        gamma=args.discount_factor,
        min_replay_history=args.min_replay_history,
        update_period=args.update_period,
        target_update_period=args.target_update_period,
        epsilon_start=args.init_exploration,
        epsilon_final=args.final_exploration,
        epsilon_decay_period=args.eps_decay_duration,
        memory_buffer_capacity=args.buffer_capacity,
        minibatch_size=args.minibatch_size,
        vqvae_embed_dim=args.vqvae_embed_dim,
        vqvae_recon_threshold=args.vqvae_recon_threshold,
        vqvae_n_res_block=args.vqvae_n_res_block,
        vqvae_sample_prob=args.vqvae_sample_prob,
        vqvae_latent_loss_weight=args.vqvae_latent_loss_weight,
        vqvae_freeze_point=args.vqvae_freeze_point,
        vqvae_buffer_capacity=args.vqvae_buffer_capacity,
        device=device,
        summary_writer=None)  # TODO: implement summary writer
    # TODO: implement memory buffer location

    # =====================================================
    # Start interacting with environment
    for episode_idx in range(args.num_episode):
        agent.begin_episode()
        observation = environment.reset()

        cumulative_reward = 0.0
        steps = 0

        while True:
            action = agent.step(observation)
            observation, reward, done, info = environment.step(action)
            # TODO: see if reward is clipped
            agent.store_transition(action, observation, reward, done)

            # Tracker variables
            cumulative_reward += reward
            steps += 1

            if done:
                # =========================================
                # Logging

                # Compute average policy-network loss for this episode
                avg_policy_net_loss = 0.0
                if agent.episode_total_policy_loss > 0.0:
                    avg_policy_net_loss = (agent.episode_total_policy_loss /
                                           agent.episode_total_optim_steps)

                # [VQVAE] Compute average reconstruction loss
                avg_vqvae_recon_loss = 0.0
                if agent.total_recon_loss > 0.0:
                    avg_vqvae_recon_loss = (agent.total_recon_loss /
                                            agent.total_recon_attempts)

                # TODO: might be nice to compute the final epsilon per episode

                # [VQVAE] Save example images at fixed episode intervals
                if args.write_img_period > 0 and (
                        episode_idx + 1) % args.write_img_period == 0:
                    print(f'Writing image (episode {episode_idx})')
                    img_name = (f'imgout_embeddim-{args.vqvae_embed_dim}'
                                f'_thresh-{args.vqvae_recon_threshold}'
                                f'_epis-{episode_idx}.png')
                    img_path = f'{args.img_out_path}/{img_name}'
                    write_image(agent, img_path, num_img=16)

                logtuple = LogTupStruct(
                    episode_idx=episode_idx,
                    steps=steps,
                    buffer_size=len(agent.memory),
                    training_steps=agent.training_steps,
                    returns=cumulative_reward,
                    policy_net_loss=avg_policy_net_loss,
                    # [VQVAE] below are all the VQ-VAE logs
                    vqvae_buffer_size=len(agent.vqvae),
                    vqvae_recon_loss=avg_vqvae_recon_loss,
                    raw_buffer_store_count=agent.episode_store_raw_count,
                    vqvae_buffer_store_count=agent.episode_store_vqvae_count)

                # Write log
                log_str = '||'.join([str(e) for e in logtuple])
                if args.log_path is not None:
                    logger.info(log_str)
                else:
                    print(log_str)

                # =========================================
                # Break out of current episode
                break
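`LogTupStruct` is constructed by keyword and then iterated field by field to build the `'||'`-joined log line, which suggests a `NamedTuple`. A plausible definition matching the call site above; the field order and types are assumptions, not the project's actual code:

from typing import NamedTuple


class LogTupStruct(NamedTuple):
    """Sketch of a per-episode log record inferred from the call site above."""
    episode_idx: int
    steps: int
    buffer_size: int
    training_steps: int
    returns: float
    policy_net_loss: float
    vqvae_buffer_size: int
    vqvae_recon_loss: float
    raw_buffer_store_count: int
    vqvae_buffer_store_count: int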
def run_environment(args: argparse.Namespace, device: str = 'cpu', logger=None):
    # =====================================================
    # Initialize environment and pre-processing
    screen_size = 84
    raw_env = gym.make(args.env_name)
    raw_env.seed(args.seed)

    # Reproducibility settings
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    environment = atari_lib.AtariPreprocessing(raw_env,
                                               frame_skip=args.frame_skips,
                                               terminal_on_life_loss=True,
                                               screen_size=screen_size)
    num_actions = environment.action_space().n
    observation_shape = (1, screen_size, screen_size)

    # =====================================================
    # Initialize agent
    agent = dqn_agent.DQNAgent(
        num_actions=num_actions,
        observation_shape=observation_shape,
        observation_dtype=torch.uint8,
        history_size=args.history_size,
        gamma=args.discount_factor,
        min_replay_history=args.min_replay_history,
        update_period=args.update_period,
        target_update_period=args.target_update_period,
        epsilon_start=args.init_exploration,
        epsilon_final=args.final_exploration,
        epsilon_decay_period=args.eps_decay_duration,
        memory_buffer_capacity=args.buffer_capacity,
        minibatch_size=args.minibatch_size,
        device=device,
        summary_writer=None)  # TODO: implement summary writer
    # TODO: implement memory buffer location

    # =====================================================
    # Start interacting with environment
    for episode_idx in range(args.num_episode):
        agent.begin_episode()
        observation = environment.reset()

        cumulative_reward = 0.0
        steps = 0

        while True:
            action = agent.step(observation)
            observation, reward, done, info = environment.step(action)
            # TODO: see if reward is clipped
            agent.store_transition(action, observation, reward, done)

            # Tracker variables
            cumulative_reward += reward
            steps += 1

            if done:
                # =========================================
                # Logging

                # Compute average policy-network loss for this episode
                avg_policy_net_loss = 0.0
                if agent.episode_total_policy_loss > 0.0:
                    avg_policy_net_loss = (agent.episode_total_policy_loss /
                                           agent.episode_total_optim_steps)

                # TODO: might be nice to compute the final epsilon per episode

                logtuple = LogTupStruct(episode_idx=episode_idx,
                                        steps=steps,
                                        buffer_size=len(agent.memory),
                                        training_steps=agent.training_steps,
                                        returns=cumulative_reward,
                                        policy_net_loss=avg_policy_net_loss)

                # Write log
                log_str = '||'.join([str(e) for e in logtuple])
                if args.log_path is not None:
                    logger.info(log_str)
                else:
                    print(log_str)

                # =========================================
                # Break out of current episode
                break
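A hypothetical entry point for the runner above. The flag names mirror the attributes read from `args`; every default below is an illustrative assumption (typical DQN-on-Atari settings), not a value taken from this project:

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Flag names match the `args` attributes used in run_environment;
    # defaults are illustrative assumptions.
    parser.add_argument('--env_name', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--frame_skips', type=int, default=4)
    parser.add_argument('--history_size', type=int, default=4)
    parser.add_argument('--discount_factor', type=float, default=0.99)
    parser.add_argument('--min_replay_history', type=int, default=20000)
    parser.add_argument('--update_period', type=int, default=4)
    parser.add_argument('--target_update_period', type=int, default=8000)
    parser.add_argument('--init_exploration', type=float, default=1.0)
    parser.add_argument('--final_exploration', type=float, default=0.01)
    parser.add_argument('--eps_decay_duration', type=int, default=250000)
    parser.add_argument('--buffer_capacity', type=int, default=100000)
    parser.add_argument('--minibatch_size', type=int, default=32)
    parser.add_argument('--num_episode', type=int, default=1000)
    parser.add_argument('--log_path', type=str, default=None)
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    run_environment(args, device=device)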