Example #1
def on_step(self, action: cyberbattle_env.Action, reward: float, done: bool, observation: cyberbattle_env.Observation):
    # Track per-(node, abstract action) success/failure counts for the step just taken
    node = cyberbattle_env.sourcenode_of_action(action)
    abstract_action = self.aa.abstract_from_gymaction(action)
    if reward > 0:
        self.success_action_count[node, abstract_action] += 1
    else:
        self.failed_action_count[node, abstract_action] += 1
    super().on_step(action, reward, done, observation)
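
The counters maintained by on_step above can be turned into a per-(node, abstract action) success-rate estimate. Below is a minimal, self-contained sketch assuming the counts live in plain dictionaries; ActionOutcomeTracker and its methods are hypothetical names for illustration, and the real agent may store the counts differently (e.g. as numpy matrices).

from collections import defaultdict


class ActionOutcomeTracker:
    # Hypothetical standalone counterpart of the success/failure counters used above
    def __init__(self):
        self.success_action_count = defaultdict(int)
        self.failed_action_count = defaultdict(int)

    def record(self, node, abstract_action, reward):
        # Mirror the bookkeeping done in on_step: a positive reward counts as a success
        if reward > 0:
            self.success_action_count[node, abstract_action] += 1
        else:
            self.failed_action_count[node, abstract_action] += 1

    def success_rate(self, node, abstract_action):
        # Fraction of attempts of this abstract action from this node that paid off
        successes = self.success_action_count[node, abstract_action]
        failures = self.failed_action_count[node, abstract_action]
        total = successes + failures
        return successes / total if total else 0.0


tracker = ActionOutcomeTracker()
tracker.record(node=0, abstract_action=3, reward=5.0)
tracker.record(node=0, abstract_action=3, reward=0.0)
print(tracker.success_rate(0, 3))  # 0.5
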
Example #2
    def metadata_from_gymaction(self, wrapped_env, gym_action):
        # Build the metadata needed to later credit this action in the learned model:
        # global state features, the acting node, and its node-specific features
        current_global_state = self.stateaction_model.global_features.get(
            wrapped_env.state, node=None)
        actor_node = cyberbattle_env.sourcenode_of_action(gym_action)
        actor_features = self.stateaction_model.node_specific_features.get(
            wrapped_env.state, actor_node)
        abstract_action = self.stateaction_model.action_space.abstract_from_gymaction(
            gym_action)
        return ChosenActionMetadata(abstract_action=abstract_action,
                                    actor_node=actor_node,
                                    actor_features=actor_features,
                                    actor_state=self.get_actor_state_vector(
                                        current_global_state, actor_features))

    def explore(self, wrapped_env: w.AgentWrapper):
        # Random exploration: sample a valid gym action and record the Q-table
        # encodings of the state and source node for later credit assignment
        agent_state = wrapped_env.state
        gym_action = wrapped_env.env.sample_valid_action(kinds=[0, 1, 2])
        abstract_action = self.qattack.action_space.abstract_from_gymaction(
            gym_action)

        assert int(abstract_action) < self.qattack.action_space.flat_size(
        ), f'Q_attack_action={abstract_action} gym_action={gym_action}'

        source_node = cyberbattle_env.sourcenode_of_action(gym_action)

        return "explore", gym_action, ChosenActionMetadata(
            Q_source_state=self.qsource.state_space.encode(agent_state),
            Q_source_expectedq=-1,
            Q_attack_expectedq=-1,
            source_node=source_node,
            source_node_encoding=self.qsource.action_space.encode_at(
                agent_state, source_node),
            abstract_action=abstract_action,
            Q_attack_state=self.qattack.state_space.encode_at(
                agent_state, source_node))

    def exploit(self, wrapped_env: w.AgentWrapper, observation):
        # Exploitation: try the credential-cache policy first, then pick a source
        # node and an abstract attack according to the learned Q-tables
        agent_state = wrapped_env.state

        qsource_state = self.qsource.state_space.encode(agent_state)

        #############
        # First, attempt to exploit the credential cache
        # using the credcache_policy
        action_style, gym_action, _ = self.credcache_policy.exploit(
            wrapped_env, observation)
        if gym_action:
            source_node = cyberbattle_env.sourcenode_of_action(gym_action)
            return action_style, gym_action, ChosenActionMetadata(
                Q_source_state=qsource_state,
                Q_source_expectedq=-1,
                Q_attack_expectedq=-1,
                source_node=source_node,
                source_node_encoding=self.qsource.action_space.encode_at(
                    agent_state, source_node),
                abstract_action=np.int32(
                    self.qattack.action_space.abstract_from_gymaction(
                        gym_action)),
                Q_attack_state=self.qattack.state_space.encode_at(
                    agent_state, source_node))
        #############

        # Pick a source-node encoding at random among those with the maximum Q-value
        action_style = "exploit"
        source_node_encoding, qsource_expectedq = self.qsource.exploit(
            qsource_state, percentile=100)

        # Pick source node at random (owned and with the desired feature encoding)
        potential_source_nodes = [
            from_node for from_node in w.owned_nodes(observation)
            if source_node_encoding == self.qsource.action_space.encode_at(
                agent_state, from_node)
        ]

        if len(potential_source_nodes) == 0:
            logging.debug(
                f'No node with encoding {source_node_encoding}, fallback on explore'
            )
            # NOTE: this fallback should not happen too often; the penalty should
            # remain much smaller than typical rewards (a small nudge,
            # not a new feedback signal).

            # Learn the lack of node availability
            self.qsource.update(qsource_state,
                                source_node_encoding,
                                qsource_state,
                                reward=0,
                                gamma=self.gamma,
                                learning_rate=self.learning_rate)

            return "exploit-1->explore", None, None
        else:
            source_node = np.random.choice(potential_source_nodes)

            qattack_state = self.qattack.state_space.encode_at(
                agent_state, source_node)

            abstract_action, qattack_expectedq = self.qattack.exploit(
                qattack_state, percentile=self.exploit_percentile)

            gym_action = self.qattack.action_space.specialize_to_gymaction(
                source_node, observation, np.int32(abstract_action))

            assert int(abstract_action) < self.qattack.action_space.flat_size(), \
                f'abstract_action={abstract_action} gym_action={gym_action}'

            if gym_action and wrapped_env.env.is_action_valid(
                    gym_action, observation['action_mask']):
                logging.debug(
                    f'  exploit gym_action={gym_action} source_node_encoding={source_node_encoding}'
                )
                return action_style, gym_action, ChosenActionMetadata(
                    Q_source_state=qsource_state,
                    Q_source_expectedq=qsource_expectedq,
                    Q_attack_expectedq=qattack_expectedq,
                    source_node=source_node,
                    source_node_encoding=source_node_encoding,
                    abstract_action=np.int32(abstract_action),
                    Q_attack_state=qattack_state)
            else:
                # NOTE: the penalty reward should be much smaller than the typical
                # non-zero reward of the environment (e.g. about 1/1000 of it).
                # Weighting the learning_rate when taking a chance is related to
                # "inverse propensity weighting".

                # Learn the non-validity of the action
                self.qsource.update(qsource_state,
                                    source_node_encoding,
                                    qsource_state,
                                    reward=0,
                                    gamma=self.gamma,
                                    learning_rate=self.learning_rate)

                self.qattack.update(qattack_state,
                                    int(abstract_action),
                                    qattack_state,
                                    reward=0,
                                    gamma=self.gamma,
                                    learning_rate=self.learning_rate)

                # Fall back to random exploration
                return ('exploit[invalid]->explore' if gym_action else
                        'exploit[undefined]->explore'), None, None
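
The exploit method above relies on qsource.exploit / qattack.exploit to pick an encoded action and on qsource.update / qattack.update to apply the learning step, but the Q-table itself is not shown. Below is a minimal sketch of a tabular learner exposing that interface, assuming a dense numpy matrix indexed by (state, action); the class name TabularQ and its internals are illustrative, and the project's actual implementation may differ.

import numpy as np


class TabularQ:
    def __init__(self, state_count: int, action_count: int):
        # Dense Q-value table: rows = encoded states, columns = encoded actions
        self.qm = np.zeros((state_count, action_count))

    def update(self, current_state: int, action: int, next_state: int,
               reward: float, gamma: float, learning_rate: float):
        # One-step Q-learning (Bellman) update:
        # Q(s, a) <- (1 - lr) * Q(s, a) + lr * (reward + gamma * max_a' Q(s', a'))
        best_next = np.max(self.qm[next_state, :])
        self.qm[current_state, action] = (
            (1 - learning_rate) * self.qm[current_state, action]
            + learning_rate * (reward + gamma * best_next))

    def exploit(self, state: int, percentile: float):
        # Return an action whose Q-value reaches the given percentile of Q(state, .)
        # together with its expected value; percentile=100 picks among the arg-max actions
        row = self.qm[state, :]
        threshold = np.percentile(row, percentile)
        candidates = np.where(row >= threshold)[0]
        action = int(np.random.choice(candidates))
        return action, float(row[action])


q = TabularQ(state_count=4, action_count=3)
q.update(current_state=0, action=1, next_state=2, reward=1.0, gamma=0.9, learning_rate=0.1)
print(q.exploit(0, percentile=100))  # (1, 0.1)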