# Imports assumed by the examples below (module paths follow the ast_toolbox /
# garage docstrings further down; adjust to the installed package layout).
# Names used only in specific examples -- e.g. tf, LocalTFRunner, the MCTS
# solvers, BPQ, ASTVectorizedSampler, Parameterized, GoExploreParameter, and
# bsddb3's `db` -- are assumed to be importable from their respective packages.
import pdb
import pickle
import random
import shelve

import gym
import numpy as np
from cached_property import cached_property
from garage.envs.base import Step
from garage.envs.env_spec import EnvSpec

from ast_toolbox.rewards import ExampleAVReward
from ast_toolbox.simulators import ExampleAVSimulator
from ast_toolbox.spaces import ExampleAVSpaces


class ASTEnv(gym.Env):
    # class ASTEnv(GarageEnv):
    def __init__(self,
                 open_loop=True,
                 blackbox_sim_state=True,
                 fixed_init_state=False,
                 s_0=None,
                 simulator=None,
                 reward_function=None,
                 spaces=None):
        # Constant hyper-params -- set by user
        self.open_loop = open_loop
        self.blackbox_sim_state = blackbox_sim_state  # is this redundant?
        self.spaces = spaces
        # These are set by reset, not the user
        self._done = False
        self._reward = 0.0
        self._info = []
        self._step = 0
        self._action = None
        self._actions = []
        self._first_step = True
        self.reward_range = (-float('inf'), float('inf'))
        self.metadata = None
        self._cum_reward = 0.0

        # Set the simulator and reward function before sampling the initial state,
        # since observation_space may be delegated to the simulator.
        self.simulator = simulator
        if self.simulator is None:
            self.simulator = ExampleAVSimulator()
        self.reward_function = reward_function
        if self.reward_function is None:
            self.reward_function = ExampleAVReward()

        if s_0 is None:
            self._init_state = self.observation_space.sample()
        else:
            self._init_state = s_0
        self._fixed_init_state = fixed_init_state

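        # Delegate rollout vectorization to the simulator when it provides its own vec_env_executor.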
        if hasattr(self.simulator, "vec_env_executor") and callable(
                getattr(self.simulator, "vec_env_executor")):
            self.vectorized = True
        else:
            self.vectorized = False
        # super().__init__(self)

    def step(self, action):
        """
        Run one timestep of the environment's dynamics. When end of episode
        is reached, reset() should be called to reset the environment's internal state.
        Input
        -----
        action : an action provided by the agent (the AST solver)
        Outputs
        -------
        (observation, reward, done, info)
        observation : agent's observation of the current environment
        reward [Float] : amount of reward due to the previous action
        done : a boolean, indicating whether the episode has ended
        info : a dictionary containing other diagnostic information from the previous action
        """
        self._action = action
        self._actions.append(action)
        action_return = self._action
        # Update simulation step
        obs = self.simulator.step(self._action)
        if (obs is None) or self.open_loop or self.blackbox_sim_state:
            obs = np.array(self._init_state)
        # if self.simulator.is_goal():
        if self.simulator.is_terminal() or self.simulator.is_goal():
            self._done = True
        # Calculate the reward for this step
        self._reward = self.reward_function.give_reward(
            action=self._action, info=self.simulator.get_reward_info())
        self._cum_reward += self._reward
        # Update instance attributes
        self._step = self._step + 1
        self._simulator_state = self.simulator.clone_state()
        self._env_state = np.concatenate((self._simulator_state,
                                          np.array([self._cum_reward]),
                                          np.array([self._step])),
                                         axis=0)
        if self._done:
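            # Sanity check: replay the recorded action sequence and drop into the
            # debugger if it does not end in a goal or terminal state.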
            self.simulator.simulate(self._actions, self._init_state)
            if not (self.simulator.is_goal() or self.simulator.is_terminal()):
                pdb.set_trace()
        return Step(
            observation=obs,
            reward=self._reward,
            done=self._done,
            cache=self._info,
            actions=action_return,
            # step = self._step -1,
            # real_actions=self._action,
            state=self._env_state,
            # root_action=self.root_action,
            is_terminal=self.simulator.is_terminal(),
            is_goal=self.simulator.is_goal())

    def simulate(self, actions):
        if not self._fixed_init_state:
            self._init_state = self.observation_space.sample()
        return self.simulator.simulate(actions, self._init_state)

    def reset(self):
        """
        Resets the state of the environment, returning an initial observation.
        Outputs
        -------
        observation : the initial observation of the space. (Initial reward is assumed to be 0.)
        """
        self._actions = []
        if not self._fixed_init_state:
            self._init_state = self.observation_space.sample()
        self._done = False
        self._reward = 0.0
        self._cum_reward = 0.0
        self._info = []
        self._action = None
        self._actions = []
        self._first_step = True
        self._step = 0

        obs = np.array(self.simulator.reset(self._init_state))
        if not self._fixed_init_state:
            obs = np.concatenate((obs, np.array(self._init_state)), axis=0)

        return obs

    @property
    def action_space(self):
        """
        Returns a Space object
        """
        if self.spaces is None:
            # return self._to_garage_space(self.simulator.action_space)
            return self.simulator.action_space
        else:
            return self.spaces.action_space

    @property
    def observation_space(self):
        """
        Returns a Space object
        """
        if self.spaces is None:
            # return self._to_garage_space(self.simulator.observation_space)
            return self.simulator.observation_space
        else:
            return self.spaces.observation_space

    def get_cache_list(self):
        return self._info

    def log(self):
        self.simulator.log()

    def render(self):
        if hasattr(self.simulator, "render") and callable(
                getattr(self.simulator, "render")):
            return self.simulator.render()
        else:
            return None

    def close(self):
        if hasattr(self.simulator, "close") and callable(
                getattr(self.simulator, "close")):
            self.simulator.close()
        else:
            return None

    def vec_env_executor(self, n_envs, max_path_length):
        return self.simulator.vec_env_executor(n_envs, max_path_length,
                                               self.reward_function,
                                               self._fixed_init_state,
                                               self._init_state,
                                               self.open_loop)

    def log_diagnostics(self, paths):
        pass

    @cached_property
    def spec(self):
        """
        Returns an EnvSpec.

        Returns:
            spec (garage.envs.EnvSpec)
        """
        return EnvSpec(observation_space=self.observation_space,
                       action_space=self.action_space)


# Example experiment runner (a separate usage example, not a method of ASTEnv).
# It assumes that sim_args, reward_args, spaces_args, env_args, bpq_args,
# algo_args, runner_args, run_experiment_args, mcts_type, and n_parallel are
# defined in the enclosing experiment script, and that tf, LocalTFRunner,
# MCTS, MCTSBV, MCTSRS, BPQ, and ASTVectorizedSampler are imported there.
def run_task(snapshot_config, *_):

        seed = 0
        # top_k = 10
        np.random.seed(seed)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            with tf.variable_scope('AST', reuse=tf.AUTO_REUSE):

                with LocalTFRunner(snapshot_config=snapshot_config,
                                   max_cpus=4,
                                   sess=sess) as local_runner:

                    # Instantiate the example classes
                    sim = ExampleAVSimulator(**sim_args)
                    reward_function = ExampleAVReward(**reward_args)
                    spaces = ExampleAVSpaces(**spaces_args)

                    # Create the environment
                    if 'id' in env_args:
                        env_args.pop('id')
                    env = ASTEnv(simulator=sim,
                                 reward_function=reward_function,
                                 spaces=spaces,
                                 **env_args)

                    top_paths = BPQ.BoundedPriorityQueue(**bpq_args)

                    if mcts_type == 'mcts':
                        print('mcts')
                        algo = MCTS(env=env, top_paths=top_paths, **algo_args)
                    elif mcts_type == 'mctsbv':
                        print('mctsbv')
                        algo = MCTSBV(env=env,
                                      top_paths=top_paths,
                                      **algo_args)
                    elif mcts_type == 'mctsrs':
                        print('mctsrs')
                        algo = MCTSRS(env=env,
                                      top_paths=top_paths,
                                      **algo_args)
                    else:
                        raise NotImplementedError

                    sampler_cls = ASTVectorizedSampler

                    local_runner.setup(algo=algo,
                                       env=env,
                                       sampler_cls=sampler_cls,
                                       sampler_args={
                                           "open_loop": False,
                                           "sim": sim,
                                           "reward_function": reward_function,
                                           "n_envs": n_parallel
                                       })

                    # Run the experiment
                    local_runner.train(**runner_args)

                    log_dir = run_experiment_args['log_dir']
                    with open(log_dir + '/best_actions.p', 'rb') as f:
                        best_actions = pickle.load(f)
                    expert_trajectories = []
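                    # Replay each saved action sequence through the simulator,
                    # recording state, reward, action, and observation at each step.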
                    for actions in best_actions:
                        sim.reset(s_0=env_args['s_0'])
                        path = []
                        for action in actions:
                            obs = sim.step(action)
                            state = sim.clone_state()
                            reward = reward_function.give_reward(
                                action=action, info=sim.get_reward_info())
                            path.append({
                                'state': state,
                                'reward': reward,
                                'action': action,
                                'observation': obs
                            })
                        expert_trajectories.append(path)
                    with open(log_dir + '/expert_trajectory.p', 'wb') as f:
                        pickle.dump(expert_trajectories, f)
class ASTEnv(gym.Env):
    r""" Gym environment to turn general AST tasks into garage compatible problems.

    Parameters
    ----------
    open_loop : bool
        True if the simulation is open-loop, meaning that AST must generate all actions ahead of time, instead
        of being able to output an action in sync with the simulator, getting an observation back before
        the next action is generated. False to get interactive control, which requires that `blackbox_sim_state`
        is also False.
    blackbox_sim_state : bool
        True if the true simulation state can not be observed, in which case actions and the initial conditions are
        used as the observation. False if the simulation state can be observed, in which case it will be used.
    fixed_init_state : bool
        True if the initial state is fixed, False to sample the initial state for each rollout from the observation
        space.
    s_0 : array_like
        The initial state for the simulation (ignored if `fixed_init_state` is False)
    simulator : :py:class:`ast_toolbox.simulators.ASTSimulator`
        The simulator wrapper, inheriting from `ast_toolbox.simulators.ASTSimulator`.
    reward_function : :py:class:`ast_toolbox.rewards.ASTReward`
        The reward function, inheriting from `ast_toolbox.rewards.ASTReward`.
    spaces : :py:class:`ast_toolbox.spaces.ASTSpaces`
        The observation and action space definitions, inheriting from `ast_toolbox.spaces.ASTSpaces`.
    """

    def __init__(self,
                 open_loop=True,
                 blackbox_sim_state=True,
                 fixed_init_state=False,
                 s_0=None,
                 simulator=None,
                 reward_function=None,
                 spaces=None):

        # Constant hyper-params -- set by user
        self.open_loop = open_loop
        self.blackbox_sim_state = blackbox_sim_state  # is this redundant?
        self.spaces = spaces
        # These are set by reset, not the user
        self._done = False
        self._reward = 0.0
        self._step = 0
        self._action = None
        self._actions = []
        self._first_step = True
        self.reward_range = (-float('inf'), float('inf'))
        self.metadata = None
        self._cum_reward = 0.0

        # Set the simulator and reward function before sampling the initial state,
        # since observation_space may be delegated to the simulator.
        self.simulator = simulator
        if self.simulator is None:
            self.simulator = ExampleAVSimulator()
        self.reward_function = reward_function
        if self.reward_function is None:
            self.reward_function = ExampleAVReward()

        if s_0 is None:
            self._init_state = self.observation_space.sample()
        else:
            self._init_state = s_0
        self._fixed_init_state = fixed_init_state

        if hasattr(self.simulator, "vec_env_executor") and callable(getattr(self.simulator, "vec_env_executor")):
            self.vectorized = True
        else:
            self.vectorized = False
        # super().__init__(self)

    def step(self, action):
        r"""
        Run one timestep of the environment's dynamics. When end of episode
        is reached, reset() should be called to reset the environment's internal state.

        Parameters
        ----------
        action : array_like
            An action provided by the agent (the AST solver).

        Returns
        -------
        : :py:func:`garage.envs.base.Step`
            A step in the rollout.
            Contains the following information:
                - observation (array_like): Agent's observation of the current environment.
                - reward (float): Amount of reward due to the previous action.
                - done (bool): Is the current step a terminal or goal state, ending the rollout.
                - actions (array_like): The action taken at the current step.
                - state (array_like): The cloned simulation state at the current cell, used for resetting if chosen to start a rollout.
                - is_terminal (bool): Whether or not the current cell is a terminal state.
                - is_goal (bool): Whether or not the current cell is a goal state.
        """
        self._env_state_before_action = self._env_state.copy()

        self._action = action
        self._actions.append(action)
        action_return = self._action

        # Update simulation step
        obs = self.simulator.step(self._action)
        if (obs is None) or self.open_loop or self.blackbox_sim_state:
            obs = np.array(self._init_state)

        if self.simulator.is_terminal() or self.simulator.is_goal():
            self._done = True

        # Calculate the reward for this step
        self._reward = self.reward_function.give_reward(
            action=self._action,
            info=self.simulator.get_reward_info())
        self._cum_reward += self._reward

        # Update instance attributes
        self._step = self._step + 1

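        # Env state encoding: [cloned simulator state..., cumulative reward, step count];
        # consumers (e.g. Go-Explore resets) split it as state[:-2], state[-2], state[-1].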
        self._simulator_state = self.simulator.clone_state()
        self._env_state = np.concatenate((self._simulator_state,
                                          np.array([self._cum_reward]),
                                          np.array([self._step])),
                                         axis=0)

        return Step(observation=obs,
                    reward=self._reward,
                    done=self._done,
                    actions=action_return,
                    state=self._env_state_before_action,
                    is_terminal=self.simulator.is_terminal(),
                    is_goal=self.simulator.is_goal())

    def simulate(self, actions):
        r"""Run a full simulation rollout.

        Parameters
        ----------
        actions : list[array_like]
            A list of array_likes, where each member is the action taken at that step.

        Returns
        -------
        int
            The step of the trajectory where a collision was found, or -1 if a collision was not found.
        dict
            A dictionary of simulation information for logging and diagnostics.
        """
        if not self._fixed_init_state:
            self._init_state = self.observation_space.sample()
        return self.simulator.simulate(actions, self._init_state)

    def reset(self):
        r"""Resets the state of the environment, returning an initial observation.

        Returns
        -------
        observation : array_like
            The initial observation of the space. (Initial reward is assumed to be 0.)
        """
        self._actions = []
        if not self._fixed_init_state:
            self._init_state = self.observation_space.sample()
        self._done = False
        self._reward = 0.0
        self._cum_reward = 0.0
        self._action = None
        self._actions = []
        self._first_step = True
        self._step = 0

        obs = np.array(self.simulator.reset(self._init_state))
        if not self._fixed_init_state:
            obs = np.concatenate((obs, np.array(self._init_state)), axis=0)

        self._simulator_state = self.simulator.clone_state()
        self._env_state = np.concatenate((self._simulator_state,
                                          np.array([self._cum_reward]),
                                          np.array([self._step])),
                                         axis=0)

        return obs

    @property
    def action_space(self):
        r"""Convenient access to the environment's action space.

        Returns
        -------
        : `gym.spaces.Space <https://gym.openai.com/docs/#spaces>`_
            The action space of the reinforcement learning problem.
        """
        if self.spaces is None:
            # return self._to_garage_space(self.simulator.action_space)
            return self.simulator.action_space
        else:
            return self.spaces.action_space

    @property
    def observation_space(self):
        r"""Convenient access to the environment's observation space.

        Returns
        -------
        : `gym.spaces.Space <https://gym.openai.com/docs/#spaces>`_
            The observation space of the reinforcement learning problem.
        """
        if self.spaces is None:
            # return self._to_garage_space(self.simulator.observation_space)
            return self.simulator.observation_space
        else:
            return self.spaces.observation_space

    def log(self):
        r"""Calls the simulator's `log` function.

        """
        self.simulator.log()

    def render(self, **kwargs):
        r"""Calls the simulator's `render` function, if it exists.

        Parameters
        ----------
        kwargs :
            Keyword arguments passed to the simulator's `render` function.

        Returns
        -------
        None or object
            Returns the output of the simulator's `render` function, or None if the simulator has no `render` function.
        """
        if hasattr(self.simulator, "render") and callable(getattr(self.simulator, "render")):
            return self.simulator.render(**kwargs)
        else:
            return None

    def close(self):
        r"""Calls the simulator's `close` function, if it exists.

        Returns
        -------
        None or object
            Returns the output of the simulator's `close` function, or None if the simulator has no `close` function.
        """
        if hasattr(self.simulator, "close") and callable(getattr(self.simulator, "close")):
            return self.simulator.close()
        else:
            return None

    @cached_property
    def spec(self):
        r"""Returns a garage environment specification.

        Returns
        -------
        :py:class:`garage.envs.env_spec.EnvSpec`
            A garage environment specification.
        """
        return EnvSpec(
            observation_space=self.observation_space,
            action_space=self.action_space)
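

# A minimal rollout sketch (not part of the toolbox itself) showing how an ASTEnv
# is typically driven once a simulator, reward function, and space definitions are
# supplied. It assumes the Example* classes imported above; random actions stand in
# for a real AST solver, and `Step` is treated as the garage namedtuple with
# observation/reward/done fields.
def _example_ast_rollout(max_steps=50):
    env = ASTEnv(simulator=ExampleAVSimulator(),
                 reward_function=ExampleAVReward(),
                 spaces=ExampleAVSpaces())
    env.reset()
    total_reward = 0.0
    for _ in range(max_steps):
        # Sample a random disturbance action and advance the simulation one step.
        step = env.step(env.action_space.sample())
        total_reward += step.reward
        if step.done:
            break
    return total_reward
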
class GoExploreASTEnv(gym.Env, Parameterized):
    r"""Gym environment to turn general AST tasks into garage compatible problems with Go-Explore style resets.

    Certain algorithms, such as Go-Explore and the Backwards Algorithm, require deterministic resets of the
    simulator. `GoExploreASTEnv` handles this by cloning simulator states and saving them in a cell structure. The
    cells are then stored in a hashed database.

    Parameters
    ----------
    open_loop : bool
        True if the simulation is open-loop, meaning that AST must generate all actions ahead of time, instead
        of being able to output an action in sync with the simulator, getting an observation back before
        the next action is generated. False to get interactive control, which requires that `blackbox_sim_state`
        is also False.
    blackbox_sim_state : bool
        True if the true simulation state can not be observed, in which case actions and the initial conditions are
        used as the observation. False if the simulation state can be observed, in which case it will be used.
    fixed_init_state : bool
        True if the initial state is fixed, False to sample the initial state for each rollout from the observation
        space.
    s_0 : array_like
        The initial state for the simulation (ignored if `fixed_init_state` is False)
    simulator : :py:class:`ast_toolbox.simulators.ASTSimulator`
        The simulator wrapper, inheriting from `ast_toolbox.simulators.ASTSimulator`.
    reward_function : :py:class:`ast_toolbox.rewards.ASTReward`
        The reward function, inheriting from `ast_toolbox.rewards.ASTReward`.
    spaces : :py:class:`ast_toolbox.spaces.ASTSpaces`
        The observation and action space definitions, inheriting from `ast_toolbox.spaces.ASTSpaces`.
    """

    def __init__(self,
                 open_loop=True,
                 blackbox_sim_state=True,
                 fixed_init_state=False,
                 s_0=None,
                 simulator=None,
                 reward_function=None,
                 spaces=None):

        # gym_env = gym.make('ast_toolbox:GoExploreAST-v0', {'test':'test string'})
        # pdb.set_trace()
        # super().__init__(gym_env)
        # Constant hyper-params -- set by user
        self.open_loop = open_loop
        self.blackbox_sim_state = blackbox_sim_state  # is this redundant?
        self.spaces = spaces
        if spaces is None:
            self.spaces = ExampleAVSpaces()
        # These are set by reset, not the user
        self._done = False
        self._reward = 0.0
        self._info = {}
        self._step = 0
        self._action = None
        self._actions = []
        self._first_step = True
        self.reward_range = (-float('inf'), float('inf'))
        self.metadata = None
        self.spec._entry_point = []
        self._cum_reward = 0.0
        self.root_action = None
        self.sample_limit = 10000

        self.simulator = simulator
        if self.simulator is None:
            self.simulator = ExampleAVSimulator()

        if s_0 is None:
            self._init_state = self.observation_space.sample()
        else:
            self._init_state = s_0
        self._fixed_init_state = fixed_init_state

        self.reward_function = reward_function
        if self.reward_function is None:
            self.reward_function = ExampleAVReward()

        if hasattr(self.simulator, "vec_env_executor") and callable(getattr(self.simulator, "vec_env_executor")):
            self.vectorized = True
        else:
            self.vectorized = False
        # super().__init__(self)
        # Always call Serializable constructor last
        self.params_set = False
        self.db_filename = 'database.dat'
        self.key_list = []
        self.max_value = 0
        self.robustify_state = []
        self.robustify = False

        Parameterized.__init__(self)

    def sample(self, population):
        r"""Sample a cell from the cell pool with likelihood proportional to cell fitness.

        The sampling is done using Stochastic Acceptance [1]_, with inspiration from John B Nelson's blog [2]_.

        The sampler rejects cells until the acceptance criterion is met. If the maximum number of rejections is
        exceeded, the sampler falls back to sampling cells uniformly until it finds one with fitness > 0. If
        the second sampling phase also exceeds the rejection limit, the function raises an exception.

        Parameters
        ----------
        population : list
            A list containing the population of cells to sample from.

        Returns
        -------
        object
            The sampled cell.

        Raises
        ------
        ValueError
            If the maximum number of rejections is exceeded in both the proportional and the uniform sampling phases.

        References
        ----------
        .. [1] Lipowski, Adam, and Dorota Lipowska. "Roulette-wheel selection via stochastic acceptance."
           Physica A: Statistical Mechanics and its Applications 391.6 (2012): 2193-2196.
           `<https://arxiv.org/pdf/1109.3627.pdf>`_
        .. [2] `<https://jbn.github.io/fast_proportional_selection/>`_
        """
        attempts = 0
        while attempts < self.sample_limit:
            attempts += 1
            candidate = population[random.choice(self.p_key_list.value)]
            if random.random() < (candidate.fitness / self.p_max_value.value):
                return candidate
        attempts = 0
        while attempts < self.sample_limit:
            attempts += 1
            candidate = population[random.choice(self.p_key_list.value)]
            if candidate.fitness > 0:
                print("Returning Uniform Random Sample - Max Attempts Reached!")
                return candidate
        print("Failed to find a valid state for reset!")
        raise ValueError
        # return population[random.choice(self.p_key_list.value)]

    def get_first_cell(self):
        r"""Returns a the observation and state of the initial state, to be used for a root cell.

        Returns
        -------
        obs : array_like
            Agent's observation of the current environment.
        state : array_like
            The cloned simulation state at the current cell, used for resetting if chosen to start a rollout.
        """

        obs = self.env_reset()
        if self.blackbox_sim_state:
            obs = self.simulator.get_first_action()

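        # Root cell state: cloned simulator state plus the cumulative reward, with
        # the step index set to -1 to mark the pre-rollout root.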
        state = np.concatenate((self.simulator.clone_state(),
                                np.array([self._cum_reward]),
                                np.array([-1])),
                               axis=0)

        return obs, state

    def step(self, action):
        r"""
        Run one timestep of the environment's dynamics. When end of episode
        is reached, reset() should be called to reset the environment's internal state.

        Parameters
        ----------
        action : array_like
            An action provided by the agent (the AST solver).

        Returns
        -------
        : :py:func:`garage.envs.base.Step`
            A step in the rollout.
            Contains the following information:
                - observation (array_like): Agent's observation of the current environment.
                - reward (float): Amount of reward due to the previous action.
                - done (bool): Is the current step a terminal or goal state, ending the rollout.
                - cache (dict): A dictionary containing other diagnostic information from the current step.
                - actions (array_like): The action taken at the current step.
                - state (array_like): The cloned simulation state at the current cell, used for resetting if chosen to start a rollout.
                - is_terminal (bool): Whether or not the current cell is a terminal state.
                - is_goal (bool): Whether or not the current cell is a goal state.
        """
        self._env_state_before_action = self._env_state.copy()

        self._action = action
        self._actions.append(action)
        action_return = self._action

        # Update simulation step
        obs = self.simulator.step(self._action)
        if (obs is None) or self.open_loop or self.blackbox_sim_state:
            obs = np.array(self._init_state)

        # End the rollout if the simulator has reached a terminal or goal state
        if self.simulator.is_terminal() or self.simulator.is_goal():
            self._done = True

        # Calculate the reward for this step
        self._reward = self.reward_function.give_reward(
            action=self._action,
            info=self.simulator.get_reward_info())
        self._cum_reward += self._reward

        # Update instance attributes
        self._step = self._step + 1
        self._simulator_state = self.simulator.clone_state()
        self._env_state = np.concatenate((self._simulator_state,
                                          np.array([self._cum_reward]),
                                          np.array([self._step])),
                                         axis=0)

        return Step(observation=obs,
                    reward=self._reward,
                    done=self._done,
                    cache=self._info,
                    actions=action_return,
                    state=self._env_state_before_action,
                    root_action=self.root_action,
                    is_terminal=self.simulator.is_terminal(),
                    is_goal=self.simulator.is_goal())

    def simulate(self, actions):
        r"""Run a full simulation rollout.

        Parameters
        ----------
        actions : list[array_like]
            A list of array_likes, where each member is the action taken at that step.

        Returns
        -------
        int
            The step of the trajectory where a collision was found, or -1 if a collision was not found.
        dict
            A dictionary of simulation information for logging and diagnostics.
        """
        if not self._fixed_init_state:
            self._init_state = self.observation_space.sample()
        return self.simulator.simulate(actions, self._init_state)

    def reset(self, **kwargs):
        r"""Resets the state of the environment, returning an initial observation.

        The reset has 2 modes.

        In the "robustify" mode (self.p_robustify_state.value is not None), the simulator resets
        the environment to `p_robustify_state.value`. It then returns the initial condition.

        In the "Go-Explore" mode, the environment attempts to sample a cell from the cell pool. If successful,
        the simulator is reset to the cell's state. On an error, the environment is reset to the intial state.

        Returns
        -------
        observation : array_like
            The initial observation of the space. (Initial reward is assumed to be 0.)
        """

        try:
            # print(self.p_robustify_state.value)
            if self.p_robustify_state is not None and self.p_robustify_state.value is not None and len(
                    self.p_robustify_state.value) > 0:
                state = self.p_robustify_state.value
                # print('-----------Robustify Init-----------------')
                # print('-----------Robustify Init: ', state, ' -----------------')
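                # Robustify state layout: [cloned simulator state..., cumulative reward, step count].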
                self.simulator.restore_state(state[:-2])
                obs = self.simulator.observation_return()
                self._done = False
                self._cum_reward = state[-2]
                self._step = state[-1]
                # pdb.set_trace()

                self.robustify = True

                self._simulator_state = self.simulator.clone_state()
                self._env_state = np.concatenate((self._simulator_state,
                                                  np.array([self._cum_reward]),
                                                  np.array([self._step])),
                                                 axis=0)
                return self._init_state

            flag = db.DB_RDONLY
            pool_DB = db.DB()
            pool_DB.open(self.p_db_filename.value, dbname=None, dbtype=db.DB_HASH, flags=flag)
            dd_pool = shelve.Shelf(pool_DB, protocol=pickle.HIGHEST_PROTOCOL)
            cell = self.sample(dd_pool)
            dd_pool.close()
            pool_DB.close()

            if cell.state is not None:
                # pdb.set_trace()
                if np.all(cell.state == 0):
                    print("-------DEFORMED CELL STATE-------")
                    obs = self.env_reset()
                else:
                    self.simulator.restore_state(cell.state[:-2])
                    if self.simulator.is_terminal() or self.simulator.is_goal():
                        print('-------SAMPLED TERMINAL STATE-------')
                        pdb.set_trace()
                        obs = self.env_reset()

                    else:
                        if cell.score == 0.0 and cell.parent is not None:
                            print("Reset to cell with score 0.0 ---- terminal: ", self.simulator.is_terminal(),
                                  " goal: ", self.simulator.is_goal(), " obs: ", cell.observation)
                        obs = self.simulator.observation_return()
                        self._done = False
                        self._cum_reward = cell.state[-2]
                        self._step = cell.state[-1]
                        self.root_action = cell.observation
            else:
                print("Reset from start")
                obs = self.env_reset()

            self._simulator_state = self.simulator.clone_state()
            self._env_state = np.concatenate((self._simulator_state,
                                              np.array([self._cum_reward]),
                                              np.array([self._step])),
                                             axis=0)
            # pdb.set_trace()
        except db.DBBusyError:
            print("DBBusyError")
            obs = self.env_reset()
        except (db.DBLockNotGrantedError, db.DBLockDeadlockError):
            print("db.DBLockNotGrantedError or db.DBLockDeadlockError")
            obs = self.env_reset()
        except db.DBForeignConflictError:
            print("DBForeignConflictError")
            obs = self.env_reset()
        except db.DBAccessError:
            print("DBAccessError")
            obs = self.env_reset()
        except db.DBPermissionsError:
            print("DBPermissionsError")
            obs = self.env_reset()
        except db.DBNoSuchFileError:
            print("DBNoSuchFileError")
            obs = self.env_reset()
        except db.DBError:
            print("DBError")
            obs = self.env_reset()
        except BaseException:
            print("Failed to get state from database")
            pdb.set_trace()
            obs = self.env_reset()

        return obs

    def env_reset(self):
        r"""Resets the state of the environment, returning an initial observation.

        Returns
        -------
        observation : array_like
            The initial observation of the space. (Initial reward is assumed to be 0.)
        """
        self._actions = []
        if not self._fixed_init_state:
            self._init_state = self.observation_space.sample()
        self._done = False
        self._reward = 0.0
        self._cum_reward = 0.0
        self._info = {'actions': []}
        self._action = self.simulator.get_first_action()
        self._actions = []
        self._first_step = True
        self._step = 0
        obs = np.array(self.simulator.reset(self._init_state))

        if not self.blackbox_sim_state:
            obs = np.concatenate((obs, np.array(self._init_state)), axis=0)

        self.root_action = self._action

        return obs

    @property
    def action_space(self):
        r"""Convenient access to the environment's action space.

        Returns
        -------
        : `gym.spaces.Space <https://gym.openai.com/docs/#spaces>`_
            The action space of the reinforcement learning problem.
        """
        if self.spaces is None:
            # return self._to_garage_space(self.simulator.action_space)
            return self.simulator.action_space
        else:
            return self.spaces.action_space

    @property
    def observation_space(self):
        r"""Convenient access to the environment's observation space.

        Returns
        -------
        : `gym.spaces.Space <https://gym.openai.com/docs/#spaces>`_
            The observation space of the reinforcement learning problem.
        """
        if self.spaces is None:
            # return self._to_garage_space(self.simulator.observation_space)
            return self.simulator.observation_space
        else:
            return self.spaces.observation_space

    def get_cache_list(self):
        """Returns the environment info cache.

        Returns
        -------
        dict
            A dictionary containing diagnostic and logging information for the environment.
        """
        return self._info

    def log(self):
        r"""Calls the simulator's `log` function.

        """
        self.simulator.log()

    def render(self, **kwargs):
        r"""Calls the simulator's `render` function, if it exists.

        Returns
        -------
        None or object
            Returns the output of the simulator's `render` function, or None if the simulator has no `render` function.
        """
        if hasattr(self.simulator, "render") and callable(getattr(self.simulator, "render")):
            return self.simulator.render(**kwargs)
        else:
            return None

    def close(self):
        r"""Calls the simulator's `close` function, if it exists.

        Returns
        -------
        None or object
            Returns the output of the simulator's `close` function, or None if the simulator has no `close` function.
        """
        if hasattr(self.simulator, "close") and callable(getattr(self.simulator, "close")):
            return self.simulator.close()
        else:
            return None

    @cached_property
    def spec(self):
        r"""Returns a garage environment specification.

        Returns
        -------
        :py:class:`garage.envs.env_spec.EnvSpec`
            A garage environment specification.
        """
        return EnvSpec(
            observation_space=self.observation_space,
            action_space=self.action_space)

    def get_params_internal(self, **tags):
        r"""Returns the parameters associated with the given tags.

        Parameters
        ----------
        tags : dict[bool]
            For each tag, a parameter is returned if the parameter name matches the tag's key.

        Returns
        -------
        list
            List of parameters
        """
        # this lasagne function also returns all var below the passed layers
        if not self.params_set:
            self.p_db_filename = GoExploreParameter("db_filename", self.db_filename)
            self.p_key_list = GoExploreParameter("key_list", self.key_list)
            self.p_max_value = GoExploreParameter("max_value", self.max_value)
            self.p_robustify_state = GoExploreParameter("robustify_state", self.robustify_state)
            self.params_set = True

        if tags.pop("db_filename", False):
            return [self.p_db_filename]

        if tags.pop("key_list", False):
            return [self.p_key_list]

        if tags.pop("max_value", False):
            return [self.p_max_value]

        if tags.pop("robustify_state", False):
            return [self.p_robustify_state]

        return [self.p_db_filename, self.p_key_list, self.p_max_value, self.p_robustify_state]  # , self.p_downsampler]

    def set_param_values(self, param_values, **tags):
        r"""Set the values of parameters

        Parameters
        ----------
        param_values : list
            Values to set the parameters to, in the order returned by `get_params`.
        tags : dict[bool]
            For each tag, a parameter is selected if the parameter name matches the tag's key.
        """
        debug = tags.pop("debug", False)

        for param, value in zip(
                self.get_params(**tags),
                param_values):
            param.set_value(value)
            if debug:
                print("setting value of %s" % param.name)

    def get_param_values(self, **tags):
        """Return the values of internal parameters.

        Parameters
        ----------
        tags : dict[bool]
            For each tag, a parameter is returned if the parameter name matches the tag's key

        Returns
        -------
        list
            A list of parameter values.
        """
        return [
            param.get_value(borrow=True) for param in self.get_params(**tags)
        ]

    def downsample(self, obs):
        """Create a downsampled approximation of the observed simulation state.

        Parameters
        ----------
        obs : array_like
            The observed simulation state.

        Returns
        -------
        array_like
            The downsampled approximation of the observed simulation state.
        """
        return obs
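

# A minimal setup sketch (not from the toolbox itself) of how a Go-Explore style
# driver might wire up GoExploreASTEnv. get_first_cell() yields the root cell's
# observation and cloned state; the cell pool database is written by the Go-Explore
# algorithm, so only the parameter plumbing is shown. set_param_values relies on
# Parameterized.get_params, and the assumed parameter order matches
# get_params_internal(): [db_filename, key_list, max_value, robustify_state].
def _example_go_explore_setup(db_filename='database.dat'):
    env = GoExploreASTEnv(simulator=ExampleAVSimulator(),
                          reward_function=ExampleAVReward(),
                          spaces=ExampleAVSpaces())
    # Root cell: observation plus cloned simulator state (step index -1).
    root_obs, root_state = env.get_first_cell()
    # Point the environment at the cell database and clear the robustify state.
    env.set_param_values([db_filename, [], 0, []], debug=True)
    return env, root_obs, root_state
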
class GoExploreASTEnv(gym.Env, Parameterized):
    def __init__(self,
                 open_loop=True,
                 blackbox_sim_state=True,
                 fixed_init_state=False,
                 s_0=None,
                 simulator=None,
                 reward_function=None,
                 spaces=None):

        # gym_env = gym.make('ast_toolbox:GoExploreAST-v0', {'test':'test string'})
        # pdb.set_trace()
        # super().__init__(gym_env)
        # Constant hyper-params -- set by user
        self.open_loop = open_loop
        self.blackbox_sim_state = blackbox_sim_state  # is this redundant?
        self.spaces = spaces
        if spaces is None:
            self.spaces = ExampleAVSpaces()
        # These are set by reset, not the user
        self._done = False
        self._reward = 0.0
        self._info = {}
        self._step = 0
        self._action = None
        self._actions = []
        self._first_step = True
        self.reward_range = (-float('inf'), float('inf'))
        self.metadata = None
        self.spec._entry_point = []
        self._cum_reward = 0.0
        self.root_action = None
        self.sample_limit = 10000

        self.simulator = simulator
        if self.simulator is None:
            self.simulator = ExampleAVSimulator()

        if s_0 is None:
            self._init_state = self.observation_space.sample()
        else:
            self._init_state = s_0
        self._fixed_init_state = fixed_init_state

        self.reward_function = reward_function
        if self.reward_function is None:
            self.reward_function = ExampleAVReward()

        if hasattr(self.simulator, "vec_env_executor") and callable(
                getattr(self.simulator, "vec_env_executor")):
            self.vectorized = True
        else:
            self.vectorized = False
        # super().__init__(self)
        # Always call Serializable constructor last
        self.params_set = False
        self.db_filename = 'database.dat'
        self.key_list = []
        self.max_value = 0
        self.robustify_state = []
        self.robustify = False

        Parameterized.__init__(self)

    def sample(self, population):
        # Proportional sampling: Stochastic Acceptance
        # https://arxiv.org/pdf/1109.3627.pdf
        # https://jbn.github.io/fast_proportional_selection/
        attempts = 0
        while attempts < self.sample_limit:
            attempts += 1
            candidate = population[random.choice(self.p_key_list.value)]
            if random.random() < (candidate.fitness / self.p_max_value.value):
                return candidate
        attempts = 0
        while attempts < self.sample_limit:
            attempts += 1
            candidate = population[random.choice(self.p_key_list.value)]
            if candidate.fitness > 0:
                print(
                    "Returning Uniform Random Sample - Max Attempts Reached!")
                return candidate
        print("Failed to find a valid state for reset!")
        raise ValueError
        # return population[random.choice(self.p_key_list.value)]

    def get_first_cell(self):
        # obs = self.env.env.reset()
        # state = self.env.env.clone_state()
        obs = self.env_reset()
        if self.blackbox_sim_state:
            # obs = self.downsample(self.simulator.get_first_action())
            obs = self.simulator.get_first_action()
            # else:
        #     obs = self.env_reset()
        # obs = self.simulator.reset(self._init_state)
        state = np.concatenate((self.simulator.clone_state(),
                                np.array([self._cum_reward]),
                                np.array([-1])),
                               axis=0)
        # pdb.set_trace()
        return obs, state

    def step(self, action):
        """
        Run one timestep of the environment's dynamics. When end of episode
        is reached, reset() should be called to reset the environment's internal state.
        Input
        -----
        action : an action provided by the agent (the AST solver)
        Outputs
        -------
        (observation, reward, done, info)
        observation : agent's observation of the current environment
        reward [Float] : amount of reward due to the previous action
        done : a boolean, indicating whether the episode has ended
        info : a dictionary containing other diagnostic information from the previous action
        """
        self._action = action
        self._actions.append(action)
        action_return = self._action
        # Update simulation step
        obs = self.simulator.step(self._action)
        if (obs is None) or self.open_loop or self.blackbox_sim_state:
            # print('Open Loop:', obs)
            obs = np.array(self._init_state)
            # if not self.robustify:
            # action_return = self.downsample(action_return)
        # if self.simulator.is_goal():
        # Add step number to differentiate identical actions
        # obs = np.concatenate((np.array([self._step]), self.downsample(obs)), axis=0)
        if self.simulator.is_terminal() or self.simulator.is_goal():
            self._done = True
        # Calculate the reward for this step
        self._reward = self.reward_function.give_reward(
            action=self._action, info=self.simulator.get_reward_info())
        self._cum_reward += self._reward
        # Update instance attributes
        self._step = self._step + 1
        self._simulator_state = self.simulator.clone_state()
        self._env_state = np.concatenate((self._simulator_state,
                                          np.array([self._cum_reward]),
                                          np.array([self._step])),
                                         axis=0)

        # if self.robustify:
        #     # No obs?
        #     # pdb.set_trace()
        #     obs = self._init_state
        # else:
        #     # print(self.robustify_state)
        #     obs = self.downsample(obs)

        # pdb.set_trace()

        return Step(
            observation=obs,
            reward=self._reward,
            done=self._done,
            cache=self._info,
            actions=action_return,
            # step = self._step -1,
            # real_actions=self._action,
            state=self._env_state,
            root_action=self.root_action,
            is_terminal=self.simulator.is_terminal(),
            is_goal=self.simulator.is_goal())

    def simulate(self, actions):
        if not self._fixed_init_state:
            self._init_state = self.observation_space.sample()
        return self.simulator.simulate(actions, self._init_state)

    def reset(self, **kwargs):
        """
        This method is necessary to suppress a deprecated warning
        thrown by gym.Wrapper.

        Calls reset on wrapped env.
        """

        try:
            # print(self.p_robustify_state.value)
            if self.p_robustify_state is not None and self.p_robustify_state.value is not None and len(
                    self.p_robustify_state.value) > 0:
                state = self.p_robustify_state.value
                print('-----------Robustify Init-----------------')
                print('-----------Robustify Init: ', state,
                      ' -----------------')
                self.simulator.restore_state(state[:-2])
                obs = self.simulator._get_obs()
                self._done = False
                self._cum_reward = state[-2]
                self._step = state[-1]
                # pdb.set_trace()

                self.robustify = True
                return self._init_state
            # pdb.set_trace()
            # start = time.time()
            flag = db.DB_RDONLY
            pool_DB = db.DB()
            # tick1 = time.time()
            pool_DB.open(self.p_db_filename.value,
                         dbname=None,
                         dbtype=db.DB_HASH,
                         flags=flag)
            # tick2 = time.time()
            dd_pool = shelve.Shelf(pool_DB, protocol=pickle.HIGHEST_PROTOCOL)
            # tick3 = time.time()
            # keys = dd_pool.keys()
            # tick4_1 = time.time()
            # list_of_keys = list(keys)
            # tick4_2 = time.time()
            # choice = random.choice(self.p_key_list.value)
            # import pdb; pdb.set_trace()
            # tick4_3 = time.time()
            # cell = dd_pool[choice]
            cell = self.sample(dd_pool)
            # tick5 = time.time()
            dd_pool.close()
            # tick6 = time.time()
            pool_DB.close()
            # tick7 = time.time()
            # print("Make DB: ", 100*(tick1 - start)/(tick7 - start), " %")
            # print("Open DB: ", 100*(tick2 - tick1) / (tick7 - start), " %")
            # print("Open Shelf: ", 100*(tick4_2 - tick2) / (tick7 - start), " %")
            # # print("Get all keys: ", 100*(tick4_1 - tick3) / (tick7 - start), " %")
            # # print("Make list of all keys: ", 100 * (tick4_2 - tick4_1) / (tick7 - start), " %")
            # print("Choose random cell: ", 100 * (tick4_3 - tick4_2) / (tick7 - start), " %")
            # print("Get random cell: ", 100*(tick5 - tick4_3) / (tick7 - start), " %")
            # print("Close shelf: ", 100*(tick6 - tick5) / (tick7 - start), " %")
            # print("Close DB: ", 100*(tick7 - tick6) / (tick7 - start), " %")
            # print("DB Access took: ", time.time() - start, " s")
            if cell.state is not None:
                # pdb.set_trace()
                if np.all(cell.state == 0):
                    print("-------DEFORMED CELL STATE-------")
                    obs = self.env_reset()
                else:
                    # print("restore state: ", cell.state)
                    self.simulator.restore_state(cell.state[:-2])
                    if self.simulator.is_terminal() or self.simulator.is_goal():
                        print('-------SAMPLED TERMINAL STATE-------')
                        pdb.set_trace()
                        obs = self.env_reset()

                    else:
                        # print("restored")
                        if cell.score == 0.0 and cell.parent is not None:
                            print(
                                "Reset to cell with score 0.0 ---- terminal: ",
                                self.simulator.is_terminal(), " goal: ",
                                self.simulator.is_goal(), " obs: ",
                                cell.observation)
                        obs = self.simulator._get_obs()
                        self._done = False
                        self._cum_reward = cell.state[-2]
                        self._step = cell.state[-1]
                        self.root_action = cell.observation
                    # print("restore obs: ", obs)
            else:
                print("Reset from start")
                obs = self.env_reset()
            # pdb.set_trace()
        except db.DBBusyError:
            print("DBBusyError")
            obs = self.env_reset()
        except (db.DBLockNotGrantedError, db.DBLockDeadlockError):
            print("db.DBLockNotGrantedError or db.DBLockDeadlockError")
            obs = self.env_reset()
        except db.DBForeignConflictError:
            print("DBForeignConflictError")
            obs = self.env_reset()
        except db.DBAccessError:
            print("DBAccessError")
            obs = self.env_reset()
        except db.DBPermissionsError:
            print("DBPermissionsError")
            obs = self.env_reset()
        except db.DBNoSuchFileError:
            print("DBNoSuchFileError")
            obs = self.env_reset()
        except db.DBError:
            print("DBError")
            obs = self.env_reset()
        except BaseException:
            print("Failed to get state from database")
            pdb.set_trace()
            obs = self.env_reset()

        return obs

    def env_reset(self):
        """
        Resets the state of the environment, returning an initial observation.
        Outputs
        -------
        observation : the initial observation of the space. (Initial reward is assumed to be 0.)
        """
        self._actions = []
        if not self._fixed_init_state:
            self._init_state = self.observation_space.sample()
        self._done = False
        self._reward = 0.0
        self._cum_reward = 0.0
        self._info = {'actions': []}
        self._action = self.simulator.get_first_action()
        self._actions = []
        self._first_step = True
        self._step = 0
        obs = np.array(self.simulator.reset(self._init_state))
        # if self.blackbox_sim_state:
        #     obs = np.array([0] * self.action_space.shape[0])
        # else:
        #     print('Not action only')
        if not self.blackbox_sim_state:
            obs = np.concatenate((obs, np.array(self._init_state)), axis=0)

        # self.root_action = self.downsample(self._action)
        self.root_action = self._action

        # obs = np.concatenate((np.array([self._step]), self.downsample(obs)), axis=0)
        return obs

    @property
    def action_space(self):
        """
        Returns a Space object
        """
        if self.spaces is None:
            # return self._to_garage_space(self.simulator.action_space)
            return self.simulator.action_space
        else:
            return self.spaces.action_space

    @property
    def observation_space(self):
        """
        Returns a Space object
        """
        if self.spaces is None:
            # return self._to_garage_space(self.simulator.observation_space)
            return self.simulator.observation_space
        else:
            return self.spaces.observation_space

    def get_cache_list(self):
        return self._info

    def log(self):
        self.simulator.log()

    def render(self, **kwargs):
        if hasattr(self.simulator, "render") and callable(
                getattr(self.simulator, "render")):
            return self.simulator.render(**kwargs)
        else:
            return None

    def close(self):
        if hasattr(self.simulator, "close") and callable(
                getattr(self.simulator, "close")):
            self.simulator.close()
        else:
            return None

    def vec_env_executor(self, n_envs, max_path_length):
        return self.simulator.vec_env_executor(n_envs, max_path_length,
                                               self.reward_function,
                                               self._fixed_init_state,
                                               self._init_state,
                                               self.open_loop)

    def log_diagnostics(self, paths):
        pass

    @cached_property
    def spec(self):
        """
        Returns an EnvSpec.

        Returns:
            spec (garage.envs.EnvSpec)
        """
        return EnvSpec(observation_space=self.observation_space,
                       action_space=self.action_space)

    def get_params_internal(self, **tags):
        # this lasagne function also returns all var below the passed layers
        if not self.params_set:
            self.p_db_filename = GoExploreParameter("db_filename",
                                                    self.db_filename)
            self.p_key_list = GoExploreParameter("key_list", self.key_list)
            self.p_max_value = GoExploreParameter("max_value", self.max_value)
            self.p_robustify_state = GoExploreParameter(
                "robustify_state", self.robustify_state)
            self.params_set = True

        if tags.pop("db_filename", False):
            return [self.p_db_filename]

        if tags.pop("key_list", False):
            return [self.p_key_list]

        if tags.pop("max_value", False):
            return [self.p_max_value]

        if tags.pop("robustify_state", False):
            return [self.p_robustify_state]

        return [
            self.p_db_filename, self.p_key_list, self.p_max_value,
            self.p_robustify_state
        ]  # , self.p_downsampler]

    def set_param_values(self, param_values, **tags):
        debug = tags.pop("debug", False)

        for param, value in zip(self.get_params(**tags), param_values):
            param.set_value(value)
            if debug:
                print("setting value of %s" % param.name)

    def get_param_values(self, **tags):
        return [
            param.get_value(borrow=True) for param in self.get_params(**tags)
        ]

    def downsample(self, obs):
        return obs

    def _get_obs(self):
        return self.simulator._get_obs()