Example #1
def start_ai_train_model(self, param):
    self.mutex.acquire()
    try:
        logger.info('start_ai_train_model start.')
        json_ret = {}
        json_ret['retcode'] = -1
        action_id = int(param['action_id'])
        if action_id == int(AIActionID.AI_START_DEEPQ.value):
            ai_type = param['ai_type']
            train_user = param['train_user']
            load_model = param['load_model']
            save_model = param['save_model']
            seed = param['seed']
            self.ddztable.add_AI_Type(ai_type)
            if train_user is not None:
                self.ddztable.set_train_user(train_user)
            table_id = self.ddztable.gettableid()
            land_user = self.ddztable.get_land_user()
            self.ddz_env = DDZEnv(self.process_id, table_id, land_user,
                                  train_user, self.ddztable)
            # Restore a pre-trained checkpoint (if any) via load_path and
            # train the DQN model on the table environment.
            self.ai_model = dqn.learn(env=self.ddz_env,
                                      seed=seed,
                                      load_path=load_model,
                                      callback=self.ddz_env.model_callback)
            if save_model is not None:
                save_path = osp.expanduser(save_model)
                self.ai_model.save(save_path)
        logger.info('start_ai_train_model end.')
    except Exception as e:
        logger.info('start_ai_train_model except: %s', e)
        json_ret = {'retcode': -1, 'errormsg': str(e)}
    finally:
        # Always release the lock, even if training raises.
        self.mutex.release()
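For reference, a minimal sketch of the `param` dict this method consumes. Only the keys are taken from the method above; the concrete values (user id, file paths) and the `worker` instance are hypothetical:

# Hypothetical request payload; key names match what the method reads.
param = {
    'action_id': AIActionID.AI_START_DEEPQ.value,
    'ai_type': 'deepq',                        # assumed value; the method only forwards it
    'train_user': 'player_001',                # hypothetical user id
    'load_model': '~/models/ddz_dqn.pkl',      # hypothetical checkpoint path
    'save_model': '~/models/ddz_dqn_new.pkl',  # hypothetical save path
    'seed': 42,
}
worker.start_ai_train_model(param)  # `worker` is a hypothetical instance of the host class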
Example #2
import os

import numpy as np

# Project-level helpers (fp, jp, models, learn, DIR) are assumed to be
# imported elsewhere in the module.


def sample_strategy_from_mixed(env, str_set, mix_str, identity, str_dict=None):

    if not isinstance(mix_str, np.ndarray):
        raise ValueError("mix_str in sample func is not a numpy array.")

    if len(str_set) != len(mix_str):
        raise ValueError(
            "Length of mixed strategies does not match number of strategies.")

    # if np.sum(mix_str) != 1:
    #     mix_str = mix_str/np.sum(mix_str)

    picked_str = np.random.choice(str_set, p=mix_str)
    # TODO: modification for fast sampling.
    if str_dict is not None:
        return str_dict[picked_str]

    if not fp.isInName('.pkl', name=picked_str):
        raise ValueError('The strategy picked is not a pickle file.')

    if identity == 0:  # pick a defender's strategy
        path = DIR + 'defender_strategies/'
    elif identity == 1:  # pick an attacker's strategy
        path = DIR + 'attacker_strategies/'
    else:
        raise ValueError("identity is neither 0 nor 1!")

    if not fp.isExist(path + picked_str):
        raise ValueError('The strategy picked does not exist!')

    if "epoch1.pkl" in picked_str:
        act = fp.load_pkl(path + picked_str)
        return act

    flag = env.training_flag
    env.set_training_flag(identity)

    param_path = os.getcwd() + '/network_parameters/param.json'
    param = jp.load_json_data(param_path)

    act = learn(env,
                network=models.mlp(num_hidden=param['num_hidden'],
                                   num_layers=param['num_layers']),
                total_timesteps=0,
                load_path=path + picked_str,
                scope=picked_str + '/')

    env.set_training_flag(flag)

    return act
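A minimal usage sketch for this function, assuming two pickled defender strategies already exist under DIR + 'defender_strategies/'; the file names and the `env` object are hypothetical:

# Hypothetical strategy files and mixing weights; identity=0 samples a defender.
str_set = ['def_str_epoch2.pkl', 'def_str_epoch3.pkl']
mix_str = np.array([0.7, 0.3])  # weights must sum to 1 for np.random.choice
act = sample_strategy_from_mixed(env, str_set, mix_str, identity=0)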
Example #3
import gym
from baselines import deepq


def main():
    env = gym.make('AtcEnv-v0')
    env.reset()

    # `callback` is expected to be defined at module level; a sketch follows
    # this example.
    act = deepq.learn(env,
                      network='mlp',
                      lr=1e-3,
                      total_timesteps=100000,
                      buffer_size=50000,
                      exploration_fraction=0.1,
                      exploration_final_eps=0.02,
                      print_freq=10,
                      callback=callback)
    print("Saving model")
    act.save("atc-gym-deepq.pkl")
Example #4
def load_action_with_default_sess(path, scope, game, training_flag):

    env = game.env
    env.set_training_flag(training_flag)

    param_path = os.getcwd() + '/network_parameters/param.json'
    param = jp.load_json_data(param_path)

    # With total_timesteps=0 no training steps run; `learn` only builds the
    # network and restores the weights from load_path.
    act = learn(env,
                network=models.mlp(num_layers=param['num_layers'],
                                   num_hidden=param['num_hidden']),
                total_timesteps=0,
                load_path=path,
                scope=scope + '/')
    return act
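A sketch of how the returned act object would typically be used: with baselines' deepq, act is callable on a batch of observations, so a single observation is wrapped in an extra batch dimension. The `game` object, path, and scope here are hypothetical:

# Hypothetical call site: restore a saved attacker policy and query one action.
act = load_action_with_default_sess(path='attacker_strategies/att_str_epoch5.pkl',
                                    scope='att_str_epoch5.pkl',
                                    game=game, training_flag=1)
obs = game.env.reset()
action = act(obs[None])[0]  # act expects a batch of observations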
Example #5
    def start_ai_train_model(self, param):
        try:
            logger.info('start_ai_train_model start.')
            json_ret = {}
            json_ret['retcode'] = -1
            action_id = int(param['action_id'])
            if action_id == int(AIActionID.AI_START_DEEPQ.value):
                ai_type = param['ai_type']
                train_user = param['train_user']
                load_model = param['load_model']
                save_model = param['save_model']
                seed = param['seed']
                ddztable = self.depence_process.ddztable
                if train_user is not None:
                    ddztable.set_train_user(train_user)
                ddztable.add_AI_Type(ai_type)
                table_id = ddztable.gettableid()
                land_user = ddztable.get_land_user()
                self.ddz_env = DDZEnv(self.process_id, table_id, land_user,
                                      train_user, ddztable)
                if ai_type == AILogicType.DeepQTrainLAND.value:
                    ddztable.set_land_env(self.ddz_env)
                elif ai_type == AILogicType.DeepQTrainFARMER_ONE.value:
                    ddztable.set_one_farmer_env(self.ddz_env)
                elif ai_type == AILogicType.DeepQTrainFARMER_TWO.value:
                    ddztable.set_two_farmer_env(self.ddz_env)

                self.ai_model = dqn.learn(env=self.ddz_env,
                                          network='mlp',
                                          seed=seed,
                                          load_path=load_model,
                                          callback=self.ddz_env.model_callback)
                if save_model is not None:
                    save_path = osp.expanduser(save_model)
                    self.ai_model.save(save_path)

                is_play = param['is_play']
                if is_play:
                    logger.info("start_ai_train_model running trained model")
                    obs = self.ddz_env.reset()
                    state = self.ai_model.initial_state if hasattr(
                        self.ai_model, 'initial_state') else None
                    dones = np.zeros((1, ))
                    episode_rew = 0
                    self.is_play_model = True
                    while self.is_play_model:
                        if state is not None:
                            actions, _, state, _ = self.ai_model.step(obs,
                                                                      S=state,
                                                                      M=dones)
                        else:
                            actions, _, _, _ = self.ai_model.step(obs)

                        obs, rew, done, _ = self.ddz_env.step(actions)
                        episode_rew += rew
                        self.ddz_env.render()
                        done = done.any() if isinstance(done,
                                                        np.ndarray) else done
                        if done:
                            print('episode_rew={}'.format(episode_rew))
                            episode_rew = 0
                            obs = self.ddz_env.reset()
                    logger.info("start_ai_train_model stop trained model")
            self.doend()
            logger.info('start_ai_train_model end.')
        except Exception as e:
            logger.info('start_ai_train_model except: %s', e)
            json_ret = {'retcode': -1, 'errormsg': str(e)}
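The replay loop above runs until self.is_play_model is cleared. A minimal sketch of how another thread could stop it, assuming the method runs in a worker thread of the same (hypothetical) `worker` object with the `param` dict from Example #1 plus an 'is_play' key:

import threading
import time

t = threading.Thread(target=worker.start_ai_train_model, args=(param,))
t.start()
time.sleep(60)                  # let the model play for a while (hypothetical duration)
worker.is_play_model = False    # the while loop in the method checks this flag
t.join()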