Code example #1
File: ppo.py  Project: toandaominh1997/automlkiller
import logging
import os
import pickle

from ray.rllib.agents.ppo import PPOTrainer
from tqdm import trange

LOGGER = logging.getLogger(__name__)


class PPOrl(object):
    """Thin wrapper around RLlib's PPOTrainer for training and rollout."""

    def __init__(self, env, env_config, config):
        self.config = config
        self.config['env_config'] = env_config
        self.env = env(env_config)  # local env instance used for rollouts
        self.agent = PPOTrainer(config=self.config, env=env)

    def fit(self, checkpoint=None, n_iter=2000, save_checkpoint=10):
        if checkpoint is None:
            checkpoint = os.path.join(os.getcwd(), 'data/checkpoint_rl.pkl')
        for idx in trange(n_iter):
            result = self.agent.train()
            LOGGER.warning('result: %s', result)
            if (idx + 1) % save_checkpoint == 0:
                LOGGER.warning('Save checkpoint at: {}'.format(idx + 1))
                # Serialize the trainer state in memory, then pickle it to disk.
                state = self.agent.save_to_object()
                with open(checkpoint, 'wb') as fp:
                    pickle.dump(state, fp, protocol=pickle.HIGHEST_PROTOCOL)
        return result

    def predict(self, checkpoint=None):
        # Optionally restore trainer state from a pickled checkpoint.
        if checkpoint is not None:
            with open(checkpoint, 'rb') as fp:
                state = pickle.load(fp)
            self.agent.restore_from_object(state)
        # Roll out a single episode with the current policy.
        done = False
        episode_reward = 0
        obs = self.env.reset()
        actions = []
        while not done:
            action = self.agent.compute_action(obs)
            actions.append(action)
            obs, reward, done, info = self.env.step(action)
            episode_reward += reward
        results = {'action': actions, 'reward': episode_reward}
        return results
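
For context, here is a minimal usage sketch of the wrapper above. Everything except PPOrl itself is an assumption made for illustration: the toy MyEnv environment, the config values, and the iteration counts are not part of the automlkiller project.

import gym
import ray
from gym.spaces import Box, Discrete
from ray.rllib.agents.ppo import DEFAULT_CONFIG


class MyEnv(gym.Env):
    """Hypothetical single-step environment, for illustration only."""

    def __init__(self, env_config):
        self.observation_space = Box(low=-1.0, high=1.0, shape=(2,))
        self.action_space = Discrete(2)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        # Reward action 1 and end the episode immediately.
        return self.observation_space.sample(), float(action), True, {}


ray.init()
config = DEFAULT_CONFIG.copy()
config['num_workers'] = 1  # assumed setting

model = PPOrl(env=MyEnv, env_config={}, config=config)
model.fit(checkpoint='checkpoint_rl.pkl', n_iter=20, save_checkpoint=10)
print(model.predict())  # rolls out one episode and returns actions/reward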
Code example #2
        "vf_clip_param": 10.0
    }

    last_improve = 150

    iteration = 22
    improved = 0
    while True:
        trainer = PPOTrainer(env="fire_mage", config=rnn_config)
        print(dir(trainer))
        #trainer.restore('./checkpoints_flush/checkpoint_379/checkpoint-379')

        step = 0
        best_val = 0.0
        if False:
            save_0 = trainer.save_to_object()
        while True:
            if False:
                save_0 = trainer.save_to_object()
                result = trainer.train()
                while result['episode_reward_mean'] > best_val:
                    print('UPENING')
                    best_save = deepcopy(save_0)
                    best_val = result['episode_reward_mean']
                    save_0 = trainer.save_to_object()
                    trainer.save('./checkpoints_flush')
                    result = trainer.train()
                print('REVERTING')
                trainer.restore_from_object(best_save)
            else:
                result = trainer.train()
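
The disabled branch in example #2 amounts to a snapshot-and-revert training loop built on save_to_object()/restore_from_object(). A condensed, self-contained sketch of that pattern follows; the CartPole environment, the config, and the iteration count are assumptions here, and only the snapshot/restore calls come from the old RLlib Trainer API used above.

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
# Assumed setup for illustration; not the "fire_mage" env from the example.
trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 1})

best_val = float("-inf")
best_save = trainer.save_to_object()  # in-memory snapshot of trainer state

for _ in range(20):
    result = trainer.train()
    mean_reward = result["episode_reward_mean"]
    if mean_reward > best_val:
        # New best: snapshot the improved weights.
        best_val = mean_reward
        best_save = trainer.save_to_object()
    else:
        # Regression: roll back to the best snapshot seen so far.
        trainer.restore_from_object(best_save)

print("best episode_reward_mean:", best_val)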