def train_attacker(
        config: ClientConfig) -> "tuple[ExperimentResult, ExperimentResult]":
    """
    Trains an attacker agent in the environment.

    The attacker implementation is chosen from ``config.attacker_type``.
    Every supported type except the OpenAI-baselines PPO agent is trained on
    an environment created with ``gym.make``. The OpenAI PPO agent is trained
    on a ``BaselineEnvWrapper`` instead, so no plain gym environment is
    created for it.

    :param config: Training configuration
    :return: (train_result, eval_result) as exposed by the trained agent
    :raises AssertionError: if ``config.attacker_type`` is not a recognized
        attacker agent type (raised before any environment is created)
    """
    # Both environment constructors below write results to the same location.
    save_dir = config.output_dir + "/results/data/" + str(config.random_seed)

    attacker: TrainAgent = None
    if config.attacker_type == AgentType.PPO_OPENAI_AGENT.value:
        # This agent drives its own wrapped environment. Building a plain
        # gym env up-front (as the previous version did) left it unused.
        wrapper_env = BaselineEnvWrapper(
            config.env_name,
            idsgame_config=config.idsgame_config,
            save_dir=save_dir,
            initial_state_path=config.initial_state_path,
            pg_agent_config=config.pg_agent_config)
        if config.title is not None:
            wrapper_env.idsgame_env.idsgame_config.render_config.title = config.title
        attacker = OpenAiPPOAgent(wrapper_env, config.pg_agent_config)
    else:
        # Maps attacker type value -> (agent class, config passed to it).
        agent_factories = {
            AgentType.TABULAR_Q_AGENT.value: (TabularQAgent, config.q_agent_config),
            AgentType.DQN_AGENT.value: (DQNAgent, config.q_agent_config),
            AgentType.REINFORCE_AGENT.value: (ReinforceAgent, config.pg_agent_config),
            AgentType.ACTOR_CRITIC_AGENT.value: (ActorCriticAgent, config.pg_agent_config),
            AgentType.BAYES_ACTOR_CRITIC_AGENT.value: (BayesActorCriticAgent, config.pg_agent_config),
            AgentType.PPO_AGENT.value: (PPOAgent, config.pg_agent_config),
        }
        # Validate before creating the environment so an invalid config does
        # not build (and leave behind) an environment it never uses.
        if config.attacker_type not in agent_factories:
            raise AssertionError(
                "Attacker train agent type not recognized: {}".format(
                    config.attacker_type))
        agent_cls, agent_config = agent_factories[config.attacker_type]

        env: IdsGameEnv = gym.make(config.env_name,
                                   idsgame_config=config.idsgame_config,
                                   save_dir=save_dir,
                                   initial_state_path=config.initial_state_path)
        if config.title is not None:
            env.idsgame_config.render_config.title = config.title
        attacker = agent_cls(env, agent_config)

    attacker.train()
    return attacker.train_result, attacker.eval_result
Example #2
0
     alpha=0.00001,
     epsilon=1,
     render=False,
     eval_sleep=0.9,
     min_epsilon=0.01,
     eval_episodes=100,
     train_log_frequency=100,
     epsilon_decay=0.9999,
     video=True,
     eval_log_frequency=1,
     video_fps=5,
     video_dir=default_output_dir() + "/results/videos/" + str(random_seed),
     num_episodes=20001,
     eval_render=False,
     gifs=True,
     gif_dir=default_output_dir() + "/results/gifs/" + str(random_seed),
     eval_frequency=1000,
     attacker=False,
     defender=True,
     video_frequency=101,
     save_dir=default_output_dir() + "/results/data/" + str(random_seed),
     dqn_config=dqn_config,
     checkpoint_freq=5000)
 env_name = "idsgame-maximal_attack-v3"
 env = gym.make(env_name,
                save_dir=default_output_dir() + "/results/data/" +
                str(random_seed))
 defender_agent = DQNAgent(env, q_agent_config)
 defender_agent.train()
 train_result = defender_agent.train_result
 eval_result = defender_agent.eval_result