Example #1
    def __init__(self,
                 agent,
                 target_agent,
                 args,
                 count_model=None,
                 buffer=None):
        self.args = args
        self.agent = agent
        self.target_agent = target_agent

        # self.parameters = self.agent.parameters()
        self.agent_parameters = list(self.agent.parameters())
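        # Use centred, Atari-style RMSprop when args.atari_rms is set;
        # otherwise only the learning rate is configured.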
        if args.atari_rms:
            self.optimiser = torch.optim.RMSprop(params=self.agent_parameters,
                                                 lr=self.args.lr,
                                                 alpha=0.95,
                                                 eps=0.00001,
                                                 centered=True)
        else:
            self.optimiser = torch.optim.RMSprop(params=self.agent_parameters,
                                                 lr=self.args.lr)

        self.logger = logging.getLogger("DQNTrainer")
        self.stats = get_stats()

        self.count_model = count_model

        self.nstep = args.n_step > 1
        self.goal_samples = 0

        self.buffer = buffer
Example #2
    def __init__(self, config, device):
        self.config = config
        self.args = config  # V. Lazy coding

        num_actions = config.num_actions
        if not self.config.count_state_action:
            raise Exception("count_state_action must be enabled")

        self.num_actions = num_actions

        self.net = dora_specifier(config.dora_name)(config).to(device)
        self.target_net = dora_specifier(config.dora_name)(config).to(device)
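        # Start the target network as an exact copy of the online network.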
        self.target_net.load_state_dict(self.net.state_dict())

        self.stats = get_stats()
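        # This model hands back the intrinsic reward directly rather than a
        # pseudo-count (see the reward_directly checks in the training loop).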
        self.reward_directly = True

        self.train_times = config.batch_size
        self.states = None
        self.states_idx = 0

        self.device = device
        self.agent_parameters = self.net.parameters()
        self.optimiser = torch.optim.RMSprop(params=self.agent_parameters,
                                             lr=self.config.lr)
Example #3
    def __init__(self, args):

        self.epsilon_start = args.epsilon_start
        self.epsilon_finish = args.epsilon_finish
        self.epsilon_time_length = args.epsilon_time_length

        self.num_actions = args.num_actions
        self.stats = get_stats()
        self.logger = logging.getLogger("EpsGreedy")
Example #4
    def __init__(self, config):
        # self.hash_counts = [HashingBonusEvaluator(dim_key=config.count_size, obs_processed_flat_dim=np.prod(config.state_shape)) for _ in range(config.num_actions)]
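        # One hash-based bonus counter over (state, action) pairs and a separate state-only counter.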
        self.hash_counts = HashingBonusEvaluator(dim_key=config.count_size, obs_processed_flat_dim=np.prod(config.state_shape), actions=config.num_actions)
        self.hash_state_count = HashingBonusEvaluator(dim_key=config.count_size, obs_processed_flat_dim=np.prod(config.state_shape), actions=1)
        self.config = config

        self.flatten = len(config.state_shape) > 1

        self.stats = get_stats()
Example #5
    def __init__(self, args):

        self.epsilon_start = args.epsilon_start
        self.epsilon_finish = args.epsilon_finish
        self.epsilon_time_length = args.epsilon_time_length

        self.num_actions = args.num_actions
        self.stats = get_stats()
        self.logger = logging.getLogger("BSPAction")

        self.config = args

        self.current_k = 0
Example #6
    def __init__(self, count_model, args):

        self.epsilon_start = args.epsilon_start
        self.epsilon_finish = args.epsilon_finish
        self.epsilon_time_length = args.epsilon_time_length

        self.num_actions = args.num_actions
        self.stats = get_stats()
        self.logger = logging.getLogger("OptimisticAction")

        self.count_model = count_model
        self.m = args.optim_m
        self.tau = args.optim_action_tau

        self.config = args
Example #7
    def __init__(self, config):
        self.config = config

        self.num_actions = self.config.num_actions

        self.state_action_counters = [
            defaultdict(lambda: 0) for _ in range(self.num_actions)
        ]

        self.target_shape = (config.atari_target_x, config.atari_target_y)
        self.max_pix_value = 255

        self.stats = get_stats()

        self.hash_vector = None
Example #8
    def __init__(self, count_model, args):

        self.epsilon_start = args.epsilon_start
        self.epsilon_finish = args.epsilon_finish
        self.epsilon_time_length = args.epsilon_time_length

        self.num_actions = args.num_actions
        self.stats = get_stats()
        self.logger = logging.getLogger("CountBonus")

        self.min_q = 1000
        self.max_q = -1000

        self.count_model = count_model
        self.args = args
Example #9
    def __init__(self, game):
        self.mont_env = gym.make("{}NoFrameskip-v4".format(game))
        # self.mario_env = modewrapper(self.mario_env)
        self.max_timesteps = 4500
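        # Apply the Atari wrappers and cap episodes at max_timesteps steps.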
        self.mont_env = make_atari(self.mont_env, max_episode_steps=self.max_timesteps)

        self.steps = 0

        self._seed = 56
        self.observation_space = self.mont_env.observation_space
        self.action_space = self.mont_env.action_space

        self.stats = get_stats()

        self.obs_dtype = np.uint8
        self.obs_scaling = 1/255.0
Example #10
    def __init__(self):
        self.mont_env = gym.make("MontezumaRevengeNoFrameskip-v4")
        # self.mario_env = modewrapper(self.mario_env)
        self.max_timesteps = 4500
        self.mont_env = make_montezuma(self.mont_env, max_episode_steps=self.max_timesteps)

        self.steps = 0

        self._seed = 56
        self.observation_space = self.mont_env.observation_space
        self.action_space = self.mont_env.action_space

        self.stats = get_stats()

        self.obs_dtype = np.uint8
        self.obs_scaling = 1/255.0
Example #11
    def __init__(self, config, device):
        self.config = config

        # assert self.config.count_state_action # For now
        num_actions = config.num_actions
        if not self.config.count_state_action:
            num_actions = 1

        self.num_actions = num_actions

        # Maybe the predictors should all start the same and have the same target
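        # One predictor/target network per action (a single pair when only counting states).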
        self.predictors = [rnd_specifier.get_pred(config.rnd_net_name)(config).to(device) for _ in range(num_actions)]
        self.targets = [rnd_specifier.get_target(config.rnd_net_name)(config).to(device) for _ in range(num_actions)]

        self.states = [None for _ in range(num_actions)]
        self.states_idx = [0 for _ in range(num_actions)]

        if self.config.rnd_same_starts:
            p_dict = self.predictors[0].state_dict()
            t_dict = self.targets[0].state_dict()
            for p in self.predictors:
                p.load_state_dict(p_dict)
            for t in self.targets:
                t.load_state_dict(t_dict)

        if self.config.count_state_only_rewards:
            self.state_predictor = rnd_specifier.get_pred(config.rnd_net_name)(config).to(device)
            self.state_target = rnd_specifier.get_target(config.rnd_net_name)(config).to(device)
            if self.config.rnd_same_starts:
                self.state_predictor.load_state_dict(p_dict)
                self.state_target.load_state_dict(t_dict)
            self.state_states = None
            self.state_states_idx = 0

        self.stats = get_stats()
        self.reward_directly = True

        # Training stuff
        self.net_parameters = []
        for n in self.predictors:
            self.net_parameters += n.parameters()
        if self.config.count_state_only_rewards:
            self.net_parameters += self.state_predictor.parameters()
        self.optimiser = torch.optim.RMSprop(params=self.net_parameters, lr=self.config.lr)

        self.train_times = self.config.rnd_train_times if self.config.rnd_train_times > 0 else self.config.rnd_batch_size
Example #12
def main(_config, _run):
    config = convert(_config)
    _id = _run._id

    # Logging stuff
    logger = logging.getLogger("Main")
    if config.mongo:
        logging.disable(logging.WARNING)
    configure_stats_logging(
        str(_id) + "_" + config.name,
        log_interval=config.log_interval,
        sacred_info=_run.info,
        use_tb=config.tb,
    )
    stats = get_stats()

    logger.critical("ID: {}".format(_id))
    # Update config with environment specific information
    env = gym.make(config.env)
    num_actions = env.action_space.n
    config = config._replace(num_actions=num_actions)
    state_shape = env.observation_space.shape
    config = config._replace(state_shape=state_shape)
    # Wrap env
    env = EnvWrapper(env, debug=True, args=config)

    # Log the config
    config_str = "Config:\n\n"
    for k, v in sorted(config._asdict().items()):
        config_str += "     {}: {}\n".format(k, v)
    logger.critical(config_str)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.critical("Device: {}".format(device.type))

    # Make agent and target agent
    agent = get_model(config.agent)(config)
    target_agent = get_model(config.agent)(config)
    target_agent.load_state_dict(agent.state_dict())
    agent.to(device)
    target_agent.to(device)

    # Pseudocount stuff
    count_model = None
    if config.count_rewards:
        if config.atari_count:
            count_model = AtariCount(config)
        elif config.rnd_net_count:
            # assert config.count_state_only_rewards
            count_model = RndNetworkDistill(config, device)
        elif config.dora_count:
            count_model = DoraCount(config, device)
        else:
            count_model = PseudoCount(config)

    # Make action selector
    action_selector = None
    if config.action_selector == "eps_greedy":
        action_selector = eps_greedy.EpsGreedy(config)
    elif config.action_selector == "optimistic_action":
        action_selector = optimistic_action.OptimisticAction(
            count_model, config)
    elif config.action_selector == "bsp":
        action_selector = bsp_action.BSPAction(config)
    else:
        raise Exception("{} is not an Action Selector!".format(
            config.action_selector))

    # Make replay buffer
    # Check if the obs dtype of the environment is an int
    obs_dtype = getattr(env.wrapped_env, "obs_dtype", np.float32)
    obs_scaling = getattr(env.wrapped_env, "obs_scaling", 1)
    replay_buffer = ReplayBuffer(size=config.buffer_size,
                                 frame_history_len=config.past_frames_input,
                                 obs_dtype=obs_dtype,
                                 obs_scaling=obs_scaling,
                                 args=config)

    if config.dora_count:
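        # Small secondary buffer of recent transitions used to train the DORA count model.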
        dora_buffer = ReplayBuffer(size=config.batch_size * 4,
                                   frame_history_len=config.past_frames_input,
                                   obs_dtype=obs_dtype,
                                   obs_scaling=obs_scaling,
                                   args=config)

    # Make trainer
    trainer = None
    if config.trainer == "DQN":
        trainer = DQNTrainer(agent=agent,
                             target_agent=target_agent,
                             args=config,
                             count_model=count_model,
                             buffer=replay_buffer)
    else:
        raise Exception("{} is not a Trainer!".format(config.trainer))
    testing_buffer = ReplayBuffer(size=(config.past_frames_input + 1),
                                  frame_history_len=config.past_frames_input,
                                  args=config)

    # Testing stuff
    testing_env = EnvWrapper(env=gym.make(config.env), debug=True, args=config)
    if config.test_augmented:
        assert config.action_selector == "optimistic_action"

    # Player Positions
    positions = set()
    action_positions = set()

    T = 0
    start_time = time.time()
    last_time = start_time

    # Lots of code duplication :(
    logger.critical("Filling buffer with {:,} random experiences.".format(
        config.buffer_burn_in))
    state = env.reset()
    assert config.buffer_burn_in == 0
    for t in range(config.buffer_burn_in):
        buffer_idx = replay_buffer.store_frame(state)
        stacked_states = replay_buffer.encode_recent_observation()
        tensor_state = torch.tensor(stacked_states, device=device).unsqueeze(0)
        action = np.random.randint(config.num_actions)
        next_state, reward, terminated, info = env.step(action)
        terminal_to_store = terminated
        if "Steps_Termination" in info and info["Steps_Termination"]:
            terminal_to_store = False

        intrinsic_reward = 0
        pseudo_count = 0
        if config.count_rewards:
            pseudo_count = count_model.visit(tensor_state, action)
            if getattr(count_model, "reward_directly", False):
                intrinsic_reward = pseudo_count
            else:
                count_bonus = config.count_beta / sqrt(pseudo_count)
                intrinsic_reward = count_bonus

        replay_buffer.store_effect(buffer_idx, action,
                                   reward - config.reward_baseline,
                                   intrinsic_reward, terminal_to_store,
                                   pseudo_count)
        state = next_state
        if terminated:
            state = env.reset()
            logger.warning("Random action burn in t: {:,}".format(t))

    state = env.reset()
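    # Per-episode bookkeeping.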
    episode = 0
    episode_reward = 0
    intrinsic_episode_reward = 0
    episode_length = 0
    env_positive_reward = 0
    max_episode_reward = 0
    if config.bsp:
        bsp_k = np.random.randint(config.bsp_k)
        action_selector.update_k(bsp_k)

    logger.critical("Beginning training.")

    while T < config.t_max:

        # Store the current state
        buffer_idx = replay_buffer.store_frame(state)
        if config.dora_count:
            dora_idx = dora_buffer.store_frame(state)

        # Get the stacked input vector
        stacked_states = replay_buffer.encode_recent_observation()

        # Get output from agent
        with torch.no_grad():
            tensor_state = torch.tensor(stacked_states,
                                        device=device).unsqueeze(0)
            agent_output = agent(tensor_state)
            # agent_output = agent(torch.Tensor(stacked_states).unsqueeze(0))

        # Select action
        action, action_info = action_selector.select_actions(
            agent_output, T, info={"state": tensor_state})

        # Take an environment step
        next_state, reward, terminated, info = env.step(action)
        T += 1
        stats.update_t(T)
        episode_reward += reward
        episode_length += 1
        terminal_to_store = terminated
        if "Steps_Termination" in info and info["Steps_Termination"]:
            logger.warning("Terminating because of episode limit")
            terminal_to_store = False

        # Log if a positive reward was ever received from environment. ~Finding goal
        if reward > 0.1:
            env_positive_reward = 1
        stats.update_stats("Positive_Reward", env_positive_reward)

        # Calculate count based intrinsic motivation
        intrinsic_reward = 0
        pseudo_count = 0
        if config.count_rewards:
            pseudo_count = count_model.visit(tensor_state, action)
            if getattr(count_model, "reward_directly", False):
                # The count-model is giving us the intrinsic reward directly
                intrinsic_reward = pseudo_count[0]
            else:
                # Count-model is giving us the pseudo-count
                count_bonus = config.count_beta / sqrt(pseudo_count)
                intrinsic_reward = count_bonus
            intrinsic_episode_reward += intrinsic_reward

        # Render training
        if config.render_train_env:
            debug_info = {}
            debug_info.update(action_info)
            env.render(debug_info=debug_info)

        # Add what happened to the buffer
        replay_buffer.store_effect(buffer_idx, action,
                                   reward - config.reward_baseline,
                                   intrinsic_reward, terminal_to_store,
                                   pseudo_count)
        if config.dora_count:
            dora_buffer.store_effect(dora_idx, action,
                                     reward - config.reward_baseline,
                                     intrinsic_reward, terminal_to_store,
                                     pseudo_count)

        # Update state
        state = next_state

        # If terminated
        if terminated:
            # If we terminated due to episode limit, we need to add the current state in
            if "Steps_Termination" in info and info["Steps_Termination"]:
                buffer_idx = replay_buffer.store_frame(state)
                replay_buffer.store_effect(buffer_idx,
                                           0,
                                           0,
                                           0,
                                           True,
                                           0,
                                           dont_sample=True)
                if config.dora_count:
                    dora_idx = dora_buffer.store_frame(state)
                    dora_buffer.store_effect(dora_idx,
                                             0,
                                             0,
                                             0,
                                             True,
                                             0,
                                             dont_sample=True)

            logger.warning("T: {:,}, Episode Reward: {:.2f}".format(
                T, episode_reward))
            state = env.reset()
            max_episode_reward = max(max_episode_reward, episode_reward)
            stats.update_stats("Episode Reward", episode_reward)
            stats.update_stats("Max Episode Reward", max_episode_reward)
            stats.update_stats("Episode Length", episode_length)
            stats.update_stats("Intrin Eps Reward", intrinsic_episode_reward)
            episode_reward = 0
            episode_length = 0
            intrinsic_episode_reward = 0
            episode += 1
            stats.update_stats("Episode", episode)
            if config.bsp:
                bsp_k = np.random.randint(config.bsp_k)
                action_selector.update_k(bsp_k)

        # Train if possible
        for _ in range(config.training_iters):
            sampled_batch = None

            if T % config.update_freq != 0:
                # Only train every update_freq timesteps
                continue
            if replay_buffer.can_sample(config.batch_size):
                sampled_batch = replay_buffer.sample(config.batch_size,
                                                     nstep=config.n_step)

            if sampled_batch is not None:
                trainer.train(sampled_batch)

            if config.dora_count:
                # Train the DORA count model on its own buffer of recent transitions.
                if dora_buffer.can_sample(config.batch_size):
                    dora_batch = dora_buffer.sample(config.batch_size,
                                                    nstep=config.n_step)
                    count_model.train(dora_batch)

        # Update target networks if necessary
        if T % config.target_update_interval == 0:
            trainer.update_target_agent()
            if config.dora_count:
                count_model.update_target_agent()

        # Logging
        if config.bsp:
            agent_output = agent_output[:, :, bsp_k]
        q_vals_numpy = agent_output.detach().cpu()[0].numpy()
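        # Log per-action Q-values only for small action spaces; otherwise log the mean.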
        if num_actions < 20:
            for action_id in range(config.num_actions):
                stats.update_stats("Q-Value_{}".format(action_id),
                                   q_vals_numpy[action_id])
        else:
            stats.update_stats("Q-Value_Mean", np.mean(q_vals_numpy))
        player_pos = env.log_visitation()
        positions.add(player_pos)
        action_positions.add((player_pos, action))
        stats.update_stats("States Visited", len(positions))
        stats.update_stats("State_Actions Visited", len(action_positions))
        stats.update_stats("Player Position", player_pos)
        # Log all env stats returned
        for k, v in info.items():
            if k != "Steps_Termination":
                stats.update_stats(k, v)

        if config.save_count_gifs > 0 and T % config.save_count_gifs == 0:
            if count_model is not None:
                state_action_counts, count_nums = env.count_state_action_space(
                    count_model)
                if state_action_counts is not None:
                    save_image(state_action_counts,
                               image_name="SA_Counts__{}_Size__{}_T".format(
                                   config.count_size, T),
                               direc_name="State_Action_Counts")
                    save_sa_count_vals(count_nums,
                                       name="SA_PCounts__{}_Size__{}_T".format(
                                           config.count_size, T),
                                       direc_name="Sa_Count_Estimates")

                actual_counts = env.state_counts()
                if actual_counts is not None:
                    save_actual_counts(actual_counts,
                                       name="Counts__{}_T".format(T),
                                       direc_name="Actual_Counts")

                q_val_img, q_vals = env.q_value_estimates(count_model, agent)
                if q_val_img is not None:
                    save_image(q_val_img,
                               image_name="Q_Vals__{}_Size__{}_T".format(
                                   config.count_size, T),
                               direc_name="Q_Value_Estimates")
                if q_vals is not None:
                    save_q_vals(q_vals,
                                name="Q_Vals__{}_Size__{}_T".format(
                                    config.count_size, T),
                                direc_name="Q_Value_Estimates")

        # Testing
        with torch.no_grad():
            if T % config.testing_interval == 0:

                prefixes = [""]
                if config.test_augmented:
                    prefixes += ["Aug_"]

                for prefix in prefixes:
                    total_test_reward = 0
                    total_test_length = 0
                    for _ in range(config.test_episodes):
                        test_episode_reward = 0
                        test_episode_length = 0
                        test_state = testing_env.reset()
                        test_env_terminated = False

                        while not test_env_terminated:
                            test_buffer_idx = testing_buffer.store_frame(
                                test_state)
                            stacked_test_states = testing_buffer.encode_recent_observation()
                            test_tensor_state = torch.tensor(
                                stacked_test_states,
                                device=device).unsqueeze(0)
                            testing_agent_output = agent(test_tensor_state)

                            if prefix == "Aug_" or config.bsp:
                                test_action, _ = action_selector.select_actions(
                                    testing_agent_output,
                                    T,
                                    info={"state": test_tensor_state},
                                    testing=True)
                            else:
                                test_action = get_test_action(
                                    testing_agent_output, config)

                            next_test_state, test_reward, test_env_terminated, _ = testing_env.step(
                                test_action)
                            if config.render_test_env:
                                testing_env.render()

                            test_episode_length += 1
                            test_episode_reward += test_reward

                            testing_buffer.store_effect(
                                test_buffer_idx, test_action, test_reward, 0,
                                test_env_terminated, 0)

                            test_state = next_test_state

                        total_test_length += test_episode_length
                        total_test_reward += test_episode_reward

                    mean_test_reward = total_test_reward / config.test_episodes
                    mean_test_length = total_test_length / config.test_episodes

                    logger.error(
                        "{}Testing -- T: {:,}/{:,}, Test Reward: {:.2f}, Test Length: {:,}"
                        .format(prefix, T, config.t_max, mean_test_reward,
                                mean_test_length))

                    stats.update_stats("{}Test Reward".format(prefix),
                                       mean_test_reward,
                                       always_log=True)
                    stats.update_stats("{}Test Episode Length".format(prefix),
                                       mean_test_length,
                                       always_log=True)

                logger.error("Estimated time left: {}. Time passed: {}".format(
                    time_left(last_time, T - config.testing_interval,
                              T, config.t_max),
                    time_str(time.time() - start_time)))
                last_time = time.time()

        if T % (config.log_interval * 4) == 0:
            stats.print_stats()

    logger.critical("Closing envs")
    env.close()
    testing_env.close()

    logger.critical("Finished training.")

    if client is not None:
        logger.critical("Attempting to close pymongo client")
        client.close()
        logger.critical("Pymongo client closed")

    logger.critical("Exiting")