Example #1
def create_agent(environment, obs_stacker, agent_type='DQN'):
    """Creates the Hanabi agent.

  Args:
    environment: The environment.
    obs_stacker: Observation stacker object.
    agent_type: str, type of agent to construct.

  Returns:
    An agent for playing Hanabi.

  Raises:
    ValueError: if an unknown agent type is requested.
  """
    if agent_type == 'DQN':
        return dqn_agent.DQNAgent(
            observation_size=obs_stacker.observation_size(),
            num_actions=environment.num_moves(),
            num_players=environment.players)
    elif agent_type == 'Rainbow':
        return rainbow_agent.RainbowAgent(
            observation_size=obs_stacker.observation_size(),
            num_actions=environment.num_moves(),
            num_players=environment.players)
    else:
        raise ValueError(
            'Expected valid agent_type, got {}'.format(agent_type))
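
A minimal usage sketch for this factory, assuming helper constructors such as run_experiment.create_environment() and run_experiment.create_obs_stacker() exist in the surrounding Hanabi training code (both helper names are assumptions, not confirmed by the snippet):

# Hypothetical helpers from the surrounding run_experiment module.
environment = run_experiment.create_environment()
obs_stacker = run_experiment.create_obs_stacker(environment)
agent = create_agent(environment, obs_stacker, agent_type='Rainbow')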
Example #2
import gym
import time
import torch

import dqn_agent

# Build the CartPole environment and grab the initial observation
# (old gym API: reset() returns only the observation array).
env = gym.make('CartPole-v1')
observation = env.reset()

# Convert the observation to a tensor and print it next to a random
# tensor of the same shape as a quick sanity check.
observation = torch.from_numpy(observation)
rand_tensor = torch.rand(observation.shape)
print("observation: ", observation)
print("random tensor: ", rand_tensor)

# Instantiate the DQN agent from the observation shape and the number of
# discrete actions, then run a single forward pass of its Q-network.
agent = dqn_agent.DQNAgent(observation.shape, env.action_space.n)

print(agent.model.forward(observation.float()))
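
For completeness, a short sketch of acting greedily from the same forward pass, assuming agent.model.forward returns one Q-value per action and using the same four-tuple gym step API as the snippet above:

q_values = agent.model.forward(observation.float())
action = int(torch.argmax(q_values).item())  # pick the greedy action
next_observation, reward, done, info = env.step(action)
print("greedy action:", action, "reward:", reward, "done:", done)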
Example #3
def run_environment(args: argparse.Namespace,
                    device: str = 'cpu',
                    logger=None):

    # =====================================================
    # Initialize environment and pre-processing

    screen_size = 84
    raw_env = gym.make(args.env_name)

    raw_env.seed(args.seed)  # reproducibility settings
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    environment = atari_lib.AtariPreprocessing(raw_env,
                                               frame_skip=args.frame_skips,
                                               terminal_on_life_loss=True,
                                               screen_size=screen_size)

    num_actions = environment.action_space().n
    observation_shape = (1, screen_size, screen_size)

    # =====================================================
    # Initialize agent
    # TODO: update this to use the vq-v2_dqn_agent
    agent = dqn_agent.DQNAgent(
        num_actions=num_actions,
        observation_shape=observation_shape,
        observation_dtype=torch.uint8,
        history_size=args.history_size,
        gamma=args.discount_factor,
        min_replay_history=args.min_replay_history,
        update_period=args.update_period,
        target_update_period=args.target_update_period,
        epsilon_start=args.init_exploration,
        epsilon_final=args.final_exploration,
        epsilon_decay_period=args.eps_decay_duration,
        memory_buffer_capacity=args.buffer_capacity,
        minibatch_size=args.minibatch_size,
        vqvae_embed_dim=args.vqvae_embed_dim,
        vqvae_recon_threshold=args.vqvae_recon_threshold,
        vqvae_n_res_block=args.vqvae_n_res_block,
        vqvae_sample_prob=args.vqvae_sample_prob,
        vqvae_latent_loss_weight=args.vqvae_latent_loss_weight,
        vqvae_freeze_point=args.vqvae_freeze_point,
        vqvae_buffer_capacity=args.vqvae_buffer_capacity,
        device=device,
        summary_writer=None)  # TODO implement summary writer
    # TODO: implement memory buffer location

    # =====================================================
    # Start interacting with environment
    for episode_idx in range(args.num_episode):

        agent.begin_episode()
        observation = environment.reset()

        cumulative_reward = 0.0
        steps = 0

        while True:

            action = agent.step(observation)
            observation, reward, done, info = environment.step(action)
            # TODO: see if reward is clipped

            agent.store_transition(action, observation, reward, done)

            # Tracker variables
            cumulative_reward += reward
            steps += 1

            if done:
                # =========================================
                # Logging stuff
                # Compute logging variables
                avg_policy_net_loss = 0.0
                if agent.episode_total_policy_loss > 0.0:
                    avg_policy_net_loss = agent.episode_total_policy_loss / \
                                          agent.episode_total_optim_steps
                # [VQVAE] compute VAE loss
                avg_vqvae_recon_loss = 0.0
                if agent.total_recon_loss > 0.0:
                    avg_vqvae_recon_loss = agent.total_recon_loss / agent.total_recon_attempts
                # TODO: might be nice to compute the final epsilon per episode

                # [VQVAE] save images at certain intervals
                if args.write_img_period > 0 and (
                        episode_idx + 1) % args.write_img_period == 0:
                    print(f'Writing image (episode {episode_idx})')
                    img_name = f'imgout_embeddim-{args.vqvae_embed_dim}_thresh-{args.vqvae_recon_threshold}_epis-{episode_idx}.png'
                    img_path = f'{args.img_out_path}/{img_name}'
                    write_image(agent, img_path, num_img=16)

                logtuple = LogTupStruct(
                    episode_idx=episode_idx,
                    steps=steps,
                    buffer_size=len(agent.memory),
                    training_steps=agent.training_steps,
                    returns=cumulative_reward,
                    policy_net_loss=avg_policy_net_loss,
                    vqvae_buffer_size=len(
                        agent.vqvae),  # [VQVAE] VQ-VAE-specific logs below
                    vqvae_recon_loss=avg_vqvae_recon_loss,
                    raw_buffer_store_count=agent.episode_store_raw_count,
                    vqvae_buffer_store_count=agent.episode_store_vqvae_count)

                # Write log
                log_str = '||'.join([str(e) for e in logtuple])
                if args.log_path is not None:
                    logger.info(log_str)
                else:
                    print(log_str)

                # =========================================
                # Break out of current episode
                break
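
A hedged reconstruction of the argparse namespace this function reads, covering only the attributes accessed above; the flag names mirror those attributes, while every default value is a hypothetical placeholder:

import argparse

parser = argparse.ArgumentParser()
# Environment and reproducibility
parser.add_argument('--env_name', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--frame_skips', type=int, default=4)
# Base DQN hyperparameters
parser.add_argument('--history_size', type=int, default=4)
parser.add_argument('--discount_factor', type=float, default=0.99)
parser.add_argument('--min_replay_history', type=int, default=20000)
parser.add_argument('--update_period', type=int, default=4)
parser.add_argument('--target_update_period', type=int, default=8000)
parser.add_argument('--init_exploration', type=float, default=1.0)
parser.add_argument('--final_exploration', type=float, default=0.01)
parser.add_argument('--eps_decay_duration', type=int, default=250000)
parser.add_argument('--buffer_capacity', type=int, default=100000)
parser.add_argument('--minibatch_size', type=int, default=32)
# VQ-VAE-specific hyperparameters
parser.add_argument('--vqvae_embed_dim', type=int, default=64)
parser.add_argument('--vqvae_recon_threshold', type=float, default=0.01)
parser.add_argument('--vqvae_n_res_block', type=int, default=2)
parser.add_argument('--vqvae_sample_prob', type=float, default=0.5)
parser.add_argument('--vqvae_latent_loss_weight', type=float, default=0.25)
parser.add_argument('--vqvae_freeze_point', type=int, default=100000)
parser.add_argument('--vqvae_buffer_capacity', type=int, default=100000)
# Run and logging options
parser.add_argument('--num_episode', type=int, default=1000)
parser.add_argument('--write_img_period', type=int, default=0)
parser.add_argument('--img_out_path', type=str, default='./images')
parser.add_argument('--log_path', type=str, default=None)

args = parser.parse_args()
run_environment(args)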
Example #4
def run_environment(args: argparse.Namespace,
                    device: str = 'cpu',
                    logger=None):

    # =====================================================
    # Initialize environment and pre-processing

    screen_size = 84
    raw_env = gym.make(args.env_name)

    raw_env.seed(args.seed)  # reproducibility settings
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    environment = atari_lib.AtariPreprocessing(raw_env,
                                               frame_skip=args.frame_skips,
                                               terminal_on_life_loss=True,
                                               screen_size=screen_size)

    num_actions = environment.action_space().n
    observation_shape = (1, screen_size, screen_size)

    # =====================================================
    # Initialize agent
    agent = dqn_agent.DQNAgent(
        num_actions=num_actions,
        observation_shape=observation_shape,
        observation_dtype=torch.uint8,
        history_size=args.history_size,
        gamma=args.discount_factor,
        min_replay_history=args.min_replay_history,
        update_period=args.update_period,
        target_update_period=args.target_update_period,
        epsilon_start=args.init_exploration,
        epsilon_final=args.final_exploration,
        epsilon_decay_period=args.eps_decay_duration,
        memory_buffer_capacity=args.buffer_capacity,
        minibatch_size=args.minibatch_size,
        device=device,
        summary_writer=None)  # TODO implement summary writer
    # TODO: implement memory buffer location

    # =====================================================
    # Start interacting with environment
    for episode_idx in range(args.num_episode):

        agent.begin_episode()
        observation = environment.reset()

        cumulative_reward = 0.0
        steps = 0

        while True:

            action = agent.step(observation)
            observation, reward, done, info = environment.step(action)
            # TODO: see if reward is clipped

            agent.store_transition(action, observation, reward, done)

            # Tracker variables
            cumulative_reward += reward
            steps += 1

            if done:
                # =========================================
                # Logging stuff
                # Compute logging variables
                avg_policy_net_loss = 0.0
                if agent.episode_total_policy_loss > 0.0:
                    avg_policy_net_loss = agent.episode_total_policy_loss / \
                                          agent.episode_total_optim_steps
                # TODO: might be nice to compute the final epsilon per episode

                logtuple = LogTupStruct(episode_idx=episode_idx,
                                        steps=steps,
                                        buffer_size=len(agent.memory),
                                        training_steps=agent.training_steps,
                                        returns=cumulative_reward,
                                        policy_net_loss=avg_policy_net_loss)

                # Write log
                log_str = '||'.join([str(e) for e in logtuple])
                if args.log_path is not None:
                    logger.info(log_str)
                else:
                    print(log_str)

                # =========================================
                # Break out of current episode
                break
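
When args.log_path is set, the function expects a logger exposing info(). A minimal sketch using the standard logging module (the logger name and format are hypothetical):

import logging

def init_logger(log_path: str) -> logging.Logger:
    # File-backed logger whose records match the '||'-separated lines built above.
    logger = logging.getLogger('dqn_run')
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler(log_path)
    handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(handler)
    return logger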