class NASimEnv(gym.Env):
    """ A simulated computer network environment for pen-testing.

    Implements the OpenAI gym interface.

    ...

    Attributes
    ----------
    name : str
        the environment scenario name
    scenario : Scenario
        Scenario object, defining the properties of the environment
    action_space : FlatActionSpace or ParameterisedActionSpace
        Action space for environment.
        If *flat_action=True* then this is a discrete action space (which
        subclasses gym.spaces.Discrete), so each action is represented by an
        integer.
        If *flat_action=False* then this is a parameterised action space (which
        subclasses gym.spaces.MultiDiscrete), so each action is represented
        using a list of parameters.
    observation_space : gym.spaces.Box
        observation space for environment.
        If *flat_obs=True* then observations are represented by a 1D vector,
        otherwise observations are represented as a 2D matrix.
    current_state : State
        the current state of the environment
    last_obs : Observation
        the last observation that was generated by environment
    steps : int
        the number of steps performed since last reset (this does not include
        generative steps)
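
    Examples
    --------
    A minimal usage sketch; assumes a ``Scenario`` object ``scenario`` is
    already available (how it is loaded is outside this class)::

        env = NASimEnv(scenario, flat_actions=True, flat_obs=True)
        obs = env.reset()                  # 1D vector, since flat_obs=True
        a = env.action_space.sample()      # int, since flat_actions=True
        obs, reward, done, info = env.step(a)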
    """
    metadata = {'render.modes': ["readable", "ASCII"]}
    reward_range = (-float('inf'), float('inf'))

    action_space = None
    observation_space = None
    current_state = None
    last_obs = None

    def __init__(self,
                 scenario,
                 fully_obs=False,
                 flat_actions=True,
                 flat_obs=True):
        """
        Parameters
        ----------
        scenario : Scenario
            Scenario object, defining the properties of the environment
        fully_obs : bool, optional
            The observability mode of environment, if True then uses fully
            observable mode, otherwise is partially observable (default=False)
        flat_actions : bool, optional
            If true then uses a flat action space, otherwise uses a
            parameterised action space (default=True).
        flat_obs : bool, optional
            If true then uses a 1D observation space, otherwise uses a 2D
            observation space (default=True)
        """
        self.name = scenario.name
        self.scenario = scenario
        self.fully_obs = fully_obs
        self.flat_actions = flat_actions
        self.flat_obs = flat_obs

        self.network = Network(scenario)
        self.current_state = State.generate_initial_state(self.network)
        self._renderer = None
        self.reset()

        if self.flat_actions:
            self.action_space = FlatActionSpace(self.scenario)
        else:
            self.action_space = ParameterisedActionSpace(self.scenario)

        if self.flat_obs:
            obs_shape = self.last_obs.shape_flat()
        else:
            obs_shape = self.last_obs.shape()
        obs_low, obs_high = Observation.get_space_bounds(self.scenario)
        self.observation_space = spaces.Box(low=obs_low,
                                            high=obs_high,
                                            shape=obs_shape)

        self.steps = 0

    def reset(self):
        """Reset the state of the environment and returns the initial state.

        Implements gym.Env.reset().

        Returns
        -------
        numpy.ndarray
            the initial observation of the environment
        """
        self.steps = 0
        self.current_state = self.network.reset(self.current_state)
        self.last_obs = self.current_state.get_initial_observation(
            self.fully_obs)

        if self.flat_obs:
            return self.last_obs.numpy_flat()
        return self.last_obs.numpy()

    def step(self, action):
        """Run one step of the environment using action.

        Implements gym.Env.step().

        Parameters
        ----------
        action : Action or int or list or numpy.ndarray
            Action to perform. If not Action object, then if using
            flat actions this should be an int and if using non-flat actions
            this should be an indexable array.

        Returns
        -------
        numpy.ndarray
            observation from performing action
        float
            reward from performing action
        bool
            whether the episode has ended or not
        dict
            auxiliary information regarding step
            (see :func:`nasim.env.action.ActionResult.info`)
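
        Examples
        --------
        A minimal interaction-loop sketch, assuming flat actions (each
        action is an integer index)::

            obs = env.reset()
            done = False
            while not done:
                action = env.action_space.sample()
                obs, reward, done, info = env.step(action)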
        """
        next_state, obs, reward, done, info = self.generative_step(
            self.current_state, action)
        self.current_state = next_state
        self.last_obs = obs

        if self.flat_obs:
            obs = obs.numpy_flat()
        else:
            obs = obs.numpy()

        self.steps += 1

        if not done and self.scenario.step_limit is not None:
            done = self.steps >= self.scenario.step_limit

        return obs, reward, done, info

    def generative_step(self, state, action):
        """Run one step of the environment using action in given state.

        Parameters
        ----------
        state : State
            The state to perform the action in
        action : Action or int or list or numpy.ndarray
            Action to perform. If not Action object, then if using
            flat actions this should be an int and if using non-flat actions
            this should be an indexable array.

        Returns
        -------
        State
            the next state after action was performed
        Observation
            observation from performing action
        float
            reward from performing action
        bool
            whether the episode has ended or not
        dict
            auxiliary information regarding step
            (see :func:`nasim.env.action.ActionResult.info`)
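
        Examples
        --------
        A sketch of a one-step lookahead from a stored state; unlike
        :func:`step`, this does not advance ``env.current_state``::

            for a in range(len(env.action_space)):
                next_state, obs, reward, done, info = \
                    env.generative_step(state, a)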
        """
        if not isinstance(action, Action):
            action = self.action_space.get_action(action)

        next_state, action_obs = self.network.perform_action(state, action)
        obs = next_state.get_observation(action, action_obs, self.fully_obs)
        done = self.goal_reached(next_state)
        reward = action_obs.value - action.cost
        return next_state, obs, reward, done, action_obs.info()

    def generate_random_initial_state(self):
        """Generates a random initial state for environment.

        This only randomizes the host configurations (os, services)
        using a uniform distribution, so may result in networks where
        it is not possible to reach the goal.

        Returns
        -------
        State
            A random initial state
        """
        return State.generate_random_initial_state(self.network)

    def generate_initial_state(self):
        """Generate the initial state for the environment.

        Returns
        -------
        State
            The initial state

        Notes
        -----
        This does not reset the current state of the environment (use
        :func:`reset` for that).
        """
        return State.generate_initial_state(self.network)

    def render(self, mode="readable", obs=None):
        """Render observation.

        See render module for more details on modes and symbols.

        Parameters
        ----------
        mode : str
            rendering mode
        obs : Observation or numpy.ndarray, optional
            the observation to render; if None, renders the last observation.
            If a numpy.ndarray, it must be in a format that matches
            Observation (i.e. the ndarray returned by the step method)
            (default=None)
        """
        if obs is None:
            obs = self.last_obs

        if not isinstance(obs, Observation):
            obs = Observation.from_numpy(obs, self.current_state.shape())

        if self._renderer is None:
            self._renderer = Viewer(self.network)

        if mode == "readable":

            self._renderer.render_readable(obs)
        else:
            print("Please choose correct render mode from :"
                  f"{self.rendering_modes}")

    def render_state(self, mode="readable", state=None):
        """Render state.

        See render module for more details on modes and symbols.

        If mode = ASCII:
            Hosts are displayed in rows, with one row for each subnet,
            ordered by host id within the subnet

        Parameters
        ----------
        mode : str
            rendering mode
        state : State or numpy.ndarray, optional
            the State to render; if None, renders the current state.
            If a numpy.ndarray, it must be in a format that matches State
            (i.e. the ndarray returned by the generative_step method)
            (default=None)
        """
        if state is None:
            state = self.current_state

        if not isinstance(state, State):
            state = State.from_numpy(state, self.current_state.shape(),
                                     self.current_state.host_num_map)

        if self._renderer is None:
            self._renderer = Viewer(self.network)

        if mode == "readable":
            self._renderer.render_readable_state(state)
        else:
            print("Please choose correct render mode from :"
                  f"{self.rendering_modes}")

    def render_action(self, action):
        """Renders human readable version of action.

        This is mainly useful for getting a text description of the action
        that corresponds to a given integer.

        Parameters
        ----------
        action : int or Action
            the action to render
        """
        if isinstance(action, int):
            action = self.action_space[action]
        print(action)

    def render_episode(self, episode, width=7, height=7):
        """Render an episode as sequence of network graphs, where an episode
        is a sequence of (state, action, reward, done) tuples generated from
        interactions with environment.

        Parameters
        ----------
        episode : list
            list of (State, Action, reward, done) tuples
        width : int
            width of GUI window
        height : int
            height of GUI window
        """
        if self._renderer is None:
            self._renderer = Viewer(self.network)
        self._renderer.render_episode(episode)

    def render_network_graph(self, ax=None, show=False):
        """Render a plot of network as a graph with hosts as nodes arranged
        into subnets and showing connections between subnets. Renders current
        state of network.

        Parameters
        ----------
        ax : Axes
            matplotlib axis to plot graph on, or None to plot on new axis
        show : bool
            whether to display the plot immediately, or just set up the
            plot so that displaying it can be handled elsewhere by the user
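
        Examples
        --------
        A sketch of embedding the graph in a larger matplotlib figure::

            import matplotlib.pyplot as plt

            fig, ax = plt.subplots()
            env.render_network_graph(ax=ax, show=False)
            plt.show()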
        """
        if self._renderer is None:
            self._renderer = Viewer(self.network)
        state = self.current_state
        self._renderer.render_graph(state, ax, show)

    def get_minimum_actions(self):
        """Get the minimum number of actions required to reach the goal.

        That is, the minimum number of actions needed to exploit all
        sensitive hosts on the network starting from the initial state.

        Returns
        -------
        int
            minimum possible actions to reach goal
        """
        return self.network.get_minimal_steps()

    def get_action_mask(self):
        """Get a vector mask for valid actions.

        Returns
        -------
        ndarray
            numpy vector of 1's and 0's, one for each action. Where an
            index will be 1 if action is valid given current state, or
            0 if action is invalid.
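
        Examples
        --------
        A sketch of sampling only valid actions using the mask (assumes
        flat actions)::

            import numpy as np

            mask = env.get_action_mask()
            action = np.random.choice(np.flatnonzero(mask))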
        """
        mask = np.zeros(len(self.action_space), dtype=np.float64)
        for i, action in enumerate(self.action_space):
            if self.network.host_discovered(action.target):
                mask[i] = 1
        return mask

    def get_score_upper_bound(self):
        """Get the theoretical upper bound for total reward for scenario.

        The theoretical upper bound is achieved when the agent exploits
        only a single host in each subnet that is required to reach the
        sensitive hosts along the shortest path in the network graph, and
        then exploits the sensitive hosts themselves (i.e. the minimal
        number of steps). It assumes an action cost of 1 and that each
        sensitive host is exploitable from any other connected subnet
        (which may not be true, hence this being an upper bound).

        Returns
        -------
        float
            theoretical max score
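
        Examples
        --------
        Illustrative numbers only: if the sensitive hosts are worth 200 in
        total and the minimal attack requires 8 steps, the bound is
        200 - 8 = 192.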
        """
        max_reward = self.network.get_total_sensitive_host_value()
        max_reward -= self.network.get_minimal_steps()
        return max_reward

    def goal_reached(self, state=None):
        """Check if the state is the goal state.

        The goal state is when all sensitive hosts have been compromised.

        Parameters
        ----------
        state : State, optional
            a state, if None will use current_state of environment
            (default=None)

        Returns
        -------
        bool
            True if state is goal state, otherwise False.
        """
        if state is None:
            state = self.current_state
        return self.network.all_sensitive_hosts_compromised(state)
class NASimEnv:
    """A simple simulated computer network with subnetworks and hosts with
    different vulnerabilities.

    Properties
    ----------
    - current_state : the current knowledge the agent has observed
    - action_space : the set of all actions allowed for environment
    - mode : the observability mode of the environment.

    The mode can be either:

    1. MDP - the state is fully observable, so after each step the actual
             next state is returned
    2. POMDP - the state is partially observable, so after each step only
               what is observed of the next state is returned

    For both modes the returned state/observation has the same dimensions.
    The only difference is that in POMDP mode, parts of the state that were
    not observed are returned as a non-observed value (i.e. 0 in most cases).
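
    Example
    -------
    A sketch of constructing the environment in each mode; assumes a
    ``Scenario`` object ``scenario`` is already available::

        mdp_env = NASimEnv(scenario, partially_obs=False)
        pomdp_env = NASimEnv(scenario, partially_obs=True)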
    """
    rendering_modes = ["readable", "ASCII"]
    env_modes = ['MDP', 'POMDP']

    action_space = None
    current_state = None

    def __init__(self, scenario, partially_obs=False):
        """
        Arguments
        ---------
        scenario : Scenario
            Scenario object, defining the properties of the environment
        partially_obs : bool
            The observability mode of environment, if True then uses partially
            observable mode, otherwise is fully observable (default=False)
        """
        self.scenario = scenario
        self.fully_obs = not partially_obs

        self.network = Network(scenario)
        self.address_space = scenario.address_space
        self.action_space = Action.load_action_space(self.scenario)

        self.current_state = State(self.network)
        self.last_obs = None
        self.renderer = None
        self.reset()

    @classmethod
    def from_file(cls, path, partially_obs):
        """Construct Environment from a scenario file.

        Arguments
        ---------
        path : str
            path to the scenario file
        partially_obs : bool
            The observability mode of environment, if True then uses partially
            observable mode, otherwise is fully observable

        Returns
        -------
        NASimEnv
            a new environment object
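
        Example
        -------
        A usage sketch (the path here is illustrative only)::

            env = NASimEnv.from_file("scenarios/my_scenario.yaml", True)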
        """
        loader = ScenarioLoader()
        scenario = loader.load(path)
        return cls(scenario, partially_obs)

    @classmethod
    def from_params(cls, num_hosts, num_services, partially_obs, **params):
        """Construct Environment from an auto generated network.

        Arguments
        ---------
        num_hosts : int
            number of hosts to include in network (minimum is 3)
        num_services : int
            number of services to use in environment (minimum is 1)
        partially_obs : bool
            The observability mode of environment, if True then uses partially
            observable mode, otherwise is fully observable
        params : dict
            generator params (see scenarios.generator for full list)

        Returns
        -------
        NASimEnv
            a new environment object
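
        Example
        -------
        A usage sketch::

            env = NASimEnv.from_params(5, 3, partially_obs=True)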
        """
        generator = ScenarioGenerator()
        scenario = generator.generate(num_hosts, num_services, **params)
        return cls(scenario, partially_obs)

    def reset(self):
        """Reset the state of the environment and returns the initial state.

        Returns
        -------
        Observation
            the initial observation of the environment
        """
        self.network.reset()
        self.current_state.reset()
        self.last_obs = self.current_state.get_initial_observation(
            self.fully_obs)
        return self.last_obs

    def step(self, action):
        """Run one step of the environment using action.

        N.B. This does not return a copy of the state, and the state is
        changed by the simulator, so if you need to store the state you will
        need to copy it (see State.copy method).

        info
        ----
        "success" : bool
            whether action was successful
        "services" : list
            list of services observed and their value (1=PRESENT, 0=ABSENT)

        Arguments
        ---------
        action : Action or int
            Action object from action space or index of action in action space

        Returns
        -------
        obs : Observation
            current observation of environment
        reward : float
            reward from performing action
        done : bool
            whether the episode has ended or not
        info : dict
            other information regarding step
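
        Example
        -------
        A sketch of a single step using an action index (assumes ``env``
        is a NASimEnv instance)::

            obs, reward, done, info = env.step(0)
            if info["success"]:
                print(info["services"], info["os"])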
        """
        assert isinstance(action, (Action, int)), \
            "Step action must be an integer or an Action object"
        if isinstance(action, int):
            action = self.action_space[action]

        action_obs = self.network.perform_action(action, self.fully_obs)
        self._update_state(action, action_obs.success)
        self.last_obs = self.current_state.get_observation(
            action, action_obs, self.fully_obs)
        done = self._is_goal()
        reward = action_obs.value - action.cost
        return self.last_obs, reward, done, {
            "success": action_obs.success,
            "services": action_obs.services,
            "os": action_obs.os
        }

    def render(self, mode="ASCII"):
        """Render last observation.

        See render module for more details on modes and symbols.

        If mode = ASCII:
            Hosts are displayed in rows, with one row for each subnet,
            ordered by host id within the subnet

        Arguments
        ---------
        mode : str
            rendering mode
        """
        if self.renderer is None:
            self.renderer = Viewer(self.network)
        if mode == "ASCI":
            self.renderer.render_asci(self.last_obs)
        elif mode == "readable":
            self.renderer.render_readable(self.last_obs)
        else:
            print("Please choose correct render mode: {0}".format(
                self.rendering_modes))

    def render_action(self, action):
        """Render a human readable version of action.

        Arguments
        ---------
        action : int or Action
            the action to render
        """
        if isinstance(action, int):
            action = self.action_space[action]
        print(action)

    def render_episode(self, episode, width=7, height=7):
        """Render an episode as sequence of network graphs, where an episode is a sequence of
        (state, action, reward, done) tuples generated from interactions with environment.

        Arguments
        ---------
        episode : list
            list of (State, Action, reward, done) tuples
        width : int
            width of GUI window
        height : int
            height of GUI window
        """
        if self.renderer is None:
            self.renderer = Viewer(self.network)
        self.renderer.render_episode(episode)

    def render_network_graph(self, ax=None, show=False):
        """Render a plot of network as a graph with hosts as nodes arranged into subnets and
        showing connections between subnets. Renders current state of network.

        Arguments
        ---------
        ax : Axes
            matplotlib axis to plot graph on, or None to plot on new axis
        show : bool
            whether to display the plot immediately, or just set up the plot
            so that displaying it can be handled elsewhere by the user
        """
        if self.renderer is None:
            self.renderer = Viewer(self.network)
        state = self.current_state
        self.renderer.render_graph(state, ax, show)

    def get_state_shape(self, flat=True):
        """Get the shape of an environment state representation

        Arguments
        ---------
        flat : bool, optional
            whether to get shape of flattened state (True) or not (False)
            (default=True)

        Returns
        -------
        (int, int)
            shape of state representation
        """
        if flat:
            return self.current_state.flat_shape()
        return self.current_state.shape()

    def get_obs_shape(self, flat=True):
        """Get the shape of an environment observation representation

        Arguments
        ---------
        flat : bool, optional
            whether to get shape of flattened observation (True) or not (False)
            (default=True)

        Returns
        -------
        (int, int)
            shape of observation representation
        """
        # observation has same shape as state
        return self.get_state_shape(flat)

    def get_num_actions(self):
        """Get the size of the action space for environment

        Returns
        -------
        num_actions : int
            action space size
        """
        return len(self.action_space)

    def get_minimum_actions(self):
        """Get the minimum possible actions required to exploit all sensitive hosts from the
        initial state

        Returns
        -------
        minimum_actions : int
            minimum possible actions
        """
        return self.network.get_minimal_steps()

    def get_action_mask(self):
        """Get a vector mask for valid actions.

        Returns
        -------
        ndarray
            numpy vector of 1's and 0's, one for each action. Where an index will
            be 1 if action is valid given current state, or 0 if action is invalid.
        """
        mask = np.zeros(len(self.action_space), dtype=np.float64)
        for i, action in enumerate(self.action_space):
            if self.network.host_discovered(action.target):
                mask[i] = 1
        return mask

    def get_best_possible_score(self):
        """Get the best score possible for this environment, assuming action cost of 1 and each
        sensitive host is exploitable from any other connected subnet.

        The theoretical best score is where the agent only exploits a single host in each subnet
        that is required to reach sensitive hosts along the shortest bath in network graph, and
        exploits the two sensitive hosts (i.e. the minial steps)

        Returns
        -------
        max_score : float
            theoretical max score
        """
        max_reward = self.network.get_total_sensitive_host_value()
        max_reward -= self.network.get_minimal_steps()
        return max_reward

    def _update_state(self, action, success):
        """Updates the current state of environment based on if action was successful and the gained
        service info

        Arguments
        ---------
        action : Action
            the action performed
        success : bool
            whether action was successful
        """
        if not success:
            return

        if action.is_exploit() or action.is_subnet_scan():
            for host_addr in self.address_space:
                self.current_state.update(host_addr)

    def _is_goal(self):
        """Check if the current state is the goal state.
        The goal state is  when all sensitive hosts have been compromised
        """
        for sensitive_m in self.network.get_sensitive_hosts():
            if not self.network.host_compromised(sensitive_m):
                # at least one sensitive host not compromised
                return False
        return True

    def __str__(self):
        output = "Environment: "
        output += "Subnets = {}, ".format(self.network.subnets)
        output += "Services = {}, ".format(self.scenario.num_services)
        return output

    def outfile_name(self):
        """Generate name for environment for use when writing to a file.

        Output format:
            <list of size of each subnet>_<number of services>
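
        For example, subnets ``[1, 1, 3]`` with 5 services gives
        ``"[1, 1, 3]_5_"``.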
        """
        output = "{}_".format(self.network.subnets)
        output += "{}_".format(self.scenario.num_services)
        return output