Ejemplo n.º 1
0
 def testBestLogdir(self):
     """Best logdirs for the default mode and for mode="min" are under test_dir and differ."""
     analysis = Analysis(self.test_dir)
     logdir = analysis.get_best_logdir(self.metric)
     self.assertTrue(logdir.startswith(self.test_dir))
     logdir2 = analysis.get_best_logdir(self.metric, mode="min")
     self.assertTrue(logdir2.startswith(self.test_dir))
     # assertNotEqual, not the deprecated assertNotEquals alias (removed in Python 3.12).
     self.assertNotEqual(logdir, logdir2)
def _record_frames(model, max_steps=210):
    """Roll out ``model`` in LunarLander-v2 and return the rendered RGB frames.

    A frame is rendered once per step (after the action is chosen, before it
    is applied), so at most ``max_steps`` frames are returned. The rollout
    stops early when ``env.step`` reports ``done``.
    """
    # we got this part from https://stable-baselines.readthedocs.io/en/master/guide/examples.html and modified it
    env = gym.make('LunarLander-v2')
    frames = []
    try:
        state = env.reset()
        for _ in range(max_steps):
            action = model.agent.act(state)
            frames.append(env.render(mode='rgb_array'))
            state, _, done, _ = env.step(action)
            if done:
                break
    finally:
        # Close the environment even if the rollout raises.
        env.close()
    return frames


def main():
    """Report the best Tune trial, reload its model and render it to a GIF.

    Loads trial results from ``TUNE_RESULTS_FOLDER``, selects the trial with
    the highest ``mean_reward``, loads the model file from that trial's
    ``checkpoint_{MAX_TRAINING_ITERATION}`` directory, writes
    ``best_model.gif`` and finally passes it to ``optimize``.
    """
    import os  # local import: the file's import header is outside this view

    metric, mode = "mean_reward", "max"
    analysis = Analysis(TUNE_RESULTS_FOLDER)
    print("Best hyperparameter {}".format(
        analysis.get_best_config(metric=metric, mode=mode)))
    best_model_path = analysis.get_best_logdir(metric=metric, mode=mode)
    print(
        "Best model found in {}, start rendering .gif".format(best_model_path))

    # NOTE(review): this config is hard-coded rather than taken from the best
    # config printed above; presumably these values do not affect the weights
    # restored by load() — confirm.
    best_model = SomeModelToTrain({
        'learning_rate': 0.1,
        'batch_size': 1,
        'target_update': 1
    })
    checkpoint_path = os.path.join(
        best_model_path, f'checkpoint_{MAX_TRAINING_ITERATION}')
    best_model.load(os.path.join(checkpoint_path, MODEL_FILENAME))

    frames = _record_frames(best_model)

    # Keep every other frame (halves the number of frames in the GIF).
    imageio.mimsave(
        'best_model.gif',
        [np.array(frame) for frame in frames[::2]],
        fps=29)
    optimize('best_model.gif')
Ejemplo n.º 3
0
 def testBestConfigIsLogdir(self):
     """For both modes, the config stored under the best logdir is the best config."""
     analysis = Analysis(self.test_dir)
     for mode in ("min", "max"):
         # subTest reports which mode failed instead of stopping at the first.
         with self.subTest(mode=mode):
             logdir = analysis.get_best_logdir(self.metric, mode=mode)
             best_config = analysis.get_best_config(self.metric, mode=mode)
             # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
             self.assertEqual(analysis.get_all_configs()[logdir], best_config)